diff --git a/src/MerkleGen.sol b/src/MerkleGen.sol index 8bbcbc1..c716e18 100644 --- a/src/MerkleGen.sol +++ b/src/MerkleGen.sol @@ -6,6 +6,7 @@ import {ArrayLib} from "./libraries/ArrayLib.sol"; /** * @notice Library for generating Merkle MultiProofs. * @author sonicskye. + * @author kamuikatsurgi. */ library MerkleGen { using ArrayLib for *; @@ -14,12 +15,109 @@ library MerkleGen { bool private constant SOURCE_FROM_PROOF = false; /** - * @notice Hashes two internal nodes to generate their parent node. + * @notice Generates a Merkle MultiProof for the selected leaves. + * @dev Constructs the necessary proof components and verifies the Merkle root. + * @dev The computed root must match the actual root of the Merkle tree. + * @param hashed_leaves The array of hashed leaves in the Merkle tree. + * @param selected_indexes The indices of the leaves to include in the proof. + * @return Sibling hashes required for the proof. + * @return Flags indicating the source of each proof hash. + * @return Merkle root of the tree. + */ + function generateMultiproof(bytes32[] memory hashed_leaves, uint256[] memory selected_indexes) + public + pure + returns (bytes32[] memory, bool[] memory, bytes32) + { + bytes32[] memory layer = hashed_leaves.copy(); + // Append with the same leaf if odd number of leaves + if (layer.length % 2 == 1) { + layer = layer.append(layer[layer.length - 1]); + } + // Create a two dimensional array + bytes32[][] memory layers = new bytes32[][](1); + layers[0] = layer; + bytes32[] memory parent_layer; + while (layer.length > 1) { + parent_layer = _computeParentLayer(layer); + layers = layers.append(parent_layer); + layer = parent_layer; + } + + bytes32[] memory proof_hashes; + bool[] memory proof_source_flags; + uint256[] memory indices = selected_indexes.copy(); + + bytes32[] memory subproof; + bool[] memory source_flags; + for (uint256 i = 0; i < layers.length - 1; i++) { + // Exclude the last layer because it's the root + layer = layers[i]; + (indices, subproof, source_flags) = _proveSingleLayer(layer, indices); + proof_hashes = proof_hashes.extend(subproof); + proof_source_flags = proof_source_flags.extend(source_flags); + } + + // Get leaves in hashed_leaves that are in selected_indexes + bytes32[] memory indexed_leaves = new bytes32[](selected_indexes.length); + for (uint256 i = 0; i < selected_indexes.length; i++) { + indexed_leaves[i] = hashed_leaves[selected_indexes[i]]; + } + + bytes32 root = _verifyComputeRoot(indexed_leaves, proof_hashes, proof_source_flags); + + // Check if computed root is the same as the root of the tree + require(root == layers[layers.length - 1][0], "Invalid root"); + + // Convert proof_source_flags to bits and uint256 + uint256 proof_flag_bits = 0; + bool[] memory proof_flag_bits_bool = new bool[](proof_source_flags.length); + for (uint256 i = 0; i < proof_source_flags.length; i++) { + if (proof_source_flags[i] == SOURCE_FROM_HASHES) { + proof_flag_bits_bool[i] = true; + proof_flag_bits = proof_flag_bits | (1 << i); + } else { + proof_flag_bits_bool[i] = false; + proof_flag_bits = proof_flag_bits | (0 << i); + } + } + + return (proof_hashes, proof_flag_bits_bool, root); + } + + /** + * @notice Generates a Merkle proof for a single leaf in the Merkle tree. + * @dev The function computes the proof and the root of the Merkle tree. + * @param leaves The array of leaves used to build the Merkle tree. + * @param leafIndex The index of the leaf for which the proof is generated. + * @return proof An array of sibling hashes forming the Merkle proof for the leaf. + * @return root The root hash of the Merkle tree. + */ + function generateSingleProof(bytes32[] memory leaves, uint256 leafIndex) + public + pure + returns (bytes32[] memory, bytes32) + { + require(leaves.length > 1, "MerkleGen: Leaves should be greater than 1."); + + // Append with the same leaf if odd number of leaves + if (leaves.length % 2 == 1) { + leaves = leaves.append(leaves[leaves.length - 1]); + } + + bytes32[] memory proof = _getProof(leaves, leafIndex); + bytes32 root = _getRoot(leaves); + + return (proof, root); + } + + /** + * @notice Hashes two leaf nodes to generate their parent node. * @param a First child node. * @param b Second child node. * @return h Hashed parent node. */ - function hash_internal_nodes(bytes32 a, bytes32 b) internal pure returns (bytes32 h) { + function _hashLeafPairs(bytes32 a, bytes32 b) internal pure returns (bytes32 h) { if (a < b) { h = keccak256(abi.encodePacked(a, b)); } else { @@ -28,12 +126,12 @@ library MerkleGen { } /** - * @notice Computes the next layer in the Merkle tree from the current layer. + * @notice Computes the parent layer in the Merkle tree from the current layer. * @dev If the current layer has an odd number of nodes, the last node is duplicated. * @param layer Current layer of the Merkle tree. - * @return Computed next layer. + * @return Computed parent layer. */ - function compute_next_layer(bytes32[] memory layer) internal pure returns (bytes32[] memory) { + function _computeParentLayer(bytes32[] memory layer) internal pure returns (bytes32[] memory) { if (layer.length == 1) { return layer; } @@ -43,13 +141,13 @@ library MerkleGen { layer = layer.append(layer[layer.length - 1]); } - bytes32[] memory next_layer; + bytes32[] memory parent_layer; for (uint256 i = 0; i < layer.length; i += 2) { - next_layer = next_layer.append(hash_internal_nodes(layer[i], layer[i + 1])); + parent_layer = parent_layer.append(_hashLeafPairs(layer[i], layer[i + 1])); } - return next_layer; + return parent_layer; } /** @@ -57,16 +155,16 @@ library MerkleGen { * @param index Current node index. * @return Parent node index. */ - function parent_index(uint256 index) internal pure returns (uint256) { + function _getParentIndex(uint256 index) internal pure returns (uint256) { return index / 2; } /** - * @notice Determines the sibling index of a given node index. + * @notice Calculates the sibling index of a given node index. * @param index Current node index. * @return Sibling node index. */ - function sibling_index(uint256 index) internal pure returns (uint256) { + function _getSiblingIndex(uint256 index) internal pure returns (uint256) { return index ^ 1; } @@ -79,7 +177,7 @@ library MerkleGen { * @return Sibling hashes required for the proof. * @return Flags indicating the source of each proof hash. */ - function prove_single_layer(bytes32[] memory layer, uint256[] memory indices) + function _proveSingleLayer(bytes32[] memory layer, uint256[] memory indices) internal pure returns (uint256[] memory, bytes32[] memory, bool[] memory) @@ -91,13 +189,13 @@ library MerkleGen { while (j < indices.length) { uint256 x = indices[j]; - next_indices = next_indices.append(parent_index(x)); + next_indices = next_indices.append(_getParentIndex(x)); - if (((j + 1) < indices.length) && (indices[j + 1] == sibling_index(x))) { + if (((j + 1) < indices.length) && (indices[j + 1] == _getSiblingIndex(x))) { j += 1; source_flags = source_flags.append(SOURCE_FROM_HASHES); } else { - auth_indices = auth_indices.append(sibling_index(x)); + auth_indices = auth_indices.append(_getSiblingIndex(x)); source_flags = source_flags.append(SOURCE_FROM_PROOF); } j += 1; @@ -122,7 +220,7 @@ library MerkleGen { * @param flag Flag to count. * @return Number of times the flag appears in the array. */ - function helper_count(bool[] memory flags, bool flag) internal pure returns (uint256) { + function _helperCount(bool[] memory flags, bool flag) internal pure returns (uint256) { uint256 count = 0; for (uint256 i = 0; i < flags.length; i++) { if (flags[i] == flag) { @@ -142,7 +240,7 @@ library MerkleGen { * @param proof_source_flags Flags indicating the source of each proof hash. * @return Computed Merkle root. */ - function verify_compute_root( + function _verifyComputeRoot( bytes32[] memory leaves, bytes32[] memory proof_hashes, bool[] memory proof_source_flags @@ -150,7 +248,7 @@ library MerkleGen { uint256 total_hashes = leaves.length + proof_hashes.length - 1; require(total_hashes == proof_source_flags.length, "MerkleGen: Invalid total hashes."); require( - helper_count(proof_source_flags, SOURCE_FROM_PROOF) == proof_hashes.length, + _helperCount(proof_source_flags, SOURCE_FROM_PROOF) == proof_hashes.length, "MerkleGen: Invalid number of proof hashes." ); @@ -192,7 +290,7 @@ library MerkleGen { } // Compute hash - hashes[i] = hash_internal_nodes(a, b); + hashes[i] = _hashLeafPairs(a, b); } if (total_hashes > 0) { @@ -203,73 +301,149 @@ library MerkleGen { } /** - * @notice Generates a Merkle MultiProof for the selected leaves. - * @dev Constructs the necessary proof components and verifies the Merkle root. - * @dev The computed root must match the actual root of the Merkle tree. - * @param hashed_leaves The array of hashed leaves in the Merkle tree. - * @param selected_indexes The indices of the leaves to include in the proof. - * @return Sibling hashes required for the proof. - * @return Flags indicating the source of each proof hash. - * @return Merkle root of the tree. + * @notice Initializes the Merkle tree by placing leaves in the correct positions. + * @dev The tree is represented as a flat array, where the leaves occupy the last `leaves.length` positions. + * @param leaves The array of leaves to be used for the Merkle tree. + * @return A flat array representing the initialized tree, with leaves placed in the correct positions. */ - function gen(bytes32[] memory hashed_leaves, uint256[] memory selected_indexes) - public - pure - returns (bytes32[] memory, bool[] memory, bytes32) - { - bytes32[] memory layer = hashed_leaves.copy(); - // Append with the same leaf if odd number of leaves - if (layer.length % 2 == 1) { - layer = layer.append(layer[layer.length - 1]); + function _initTree(bytes32[] memory leaves) internal pure returns (bytes32[] memory) { + require(leaves.length > 1, "MerkleGen: Leaves should be greater than 1."); + + bytes32[] memory tree = new bytes32[](2 * leaves.length - 1); + + uint256 index = tree.length - leaves.length; + + for (uint256 i = 0; i < leaves.length; i++) { + tree[index + i] = leaves[i]; } - // Create a two dimensional array - bytes32[][] memory layers = new bytes32[][](1); - layers[0] = layer; - bytes32[] memory next_layer; - while (layer.length > 1) { - next_layer = compute_next_layer(layer); - layers = layers.append(next_layer); - layer = next_layer; + + return tree; + } + + /** + * @notice Builds the complete Merkle tree from the given leaves. + * @dev The function computes the parent nodes from the leaves up to the root of the tree. + * @param leaves The array of leaves to build the Merkle tree. + * @return A flat array representing the complete Merkle tree. + */ + function _buildTree(bytes32[] memory leaves) internal pure returns (bytes32[] memory) { + bytes32[] memory tree = _initTree(leaves); + + for (uint256 i = tree.length - 1; i > 1; i -= 2) { + bytes32 left = tree[i - 1]; + bytes32 right = tree[i]; + bytes32 parent = _hashLeafPairs(left, right); + uint256 parentIndex = (i - 1) / 2; + tree[parentIndex] = parent; } - bytes32[] memory proof_hashes; - bool[] memory proof_source_flags; - uint256[] memory indices = selected_indexes.copy(); + return tree; + } - bytes32[] memory subproof; - bool[] memory source_flags; - for (uint256 i = 0; i < layers.length - 1; i++) { - // Exclude the last layer because it's the root - layer = layers[i]; - (indices, subproof, source_flags) = prove_single_layer(layer, indices); - proof_hashes = proof_hashes.extend(subproof); - proof_source_flags = proof_source_flags.extend(source_flags); + /** + * @notice Returns the root hash of the Merkle tree constructed from the given leaves. + * @dev The tree is built and the root (the first element of the tree) is returned. + * @param leaves The array of leaves to build the Merkle tree. + * @return The root hash of the Merkle tree. + */ + function _getRoot(bytes32[] memory leaves) internal pure returns (bytes32) { + require(leaves.length > 1, "MerkleGen: Data should be greater than 1."); + + bytes32[] memory tree = _buildTree(leaves); + + return tree[0]; + } + + /** + * @notice Generates the Merkle proof for a specific leaf index. + * @dev Traverses the tree from the leaf at the specified index to the root, collecting the sibling hashes required for proof. + * @param leaves The array of leaves for the Merkle tree. + * @param index The index of the leaf for which the proof is generated. + * @return An array of sibling hashes forming the Merkle proof for the leaf at the specified index. + */ + function _getProof(bytes32[] memory leaves, uint256 index) internal pure returns (bytes32[] memory) { + require(leaves.length > 1, "MerkleGen: Leaves should be greater than 1."); + + bytes32[] memory tree = _buildTree(leaves); + + uint256 proofLength = _log2CeilBitMagic(leaves.length); + bytes32[] memory proof = new bytes32[](proofLength); + + uint256 proofIndex = 0; + + uint256 currentIndex = leaves.length - 1 + index; + + while (currentIndex > 0) { + uint256 siblingIndex = (currentIndex % 2 == 0) ? currentIndex - 1 : currentIndex + 1; + + if (siblingIndex < tree.length) { + proof[proofIndex] = tree[siblingIndex]; + proofIndex++; + } + + currentIndex = (currentIndex - 1) / 2; } - // Get leaves in hashed_leaves that are in selected_indexes - bytes32[] memory indexed_leaves = new bytes32[](selected_indexes.length); - for (uint256 i = 0; i < selected_indexes.length; i++) { - indexed_leaves[i] = hashed_leaves[selected_indexes[i]]; + bytes32[] memory finalProof = new bytes32[](proofIndex); + for (uint256 i = 0; i < proofIndex; i++) { + finalProof[i] = proof[i]; } - bytes32 root = verify_compute_root(indexed_leaves, proof_hashes, proof_source_flags); + return finalProof; + } - // Check if computed root is the same as the root of the tree - require(root == layers[layers.length - 1][0], "Invalid root"); + /** + * @notice Computes the ceiling of the base-2 logarithm of a number using bitwise operations. + * @dev This is an optimized method to compute the log2 value, rounded up to the nearest integer. + * @param x The number for which the log2 ceiling is computed. + * @return The smallest integer greater than or equal to log2(x). + */ + function _log2CeilBitMagic(uint256 x) internal pure returns (uint256) { + if (x <= 1) { + return 0; + } - // Convert proof_source_flags to bits and uint256 - uint256 proof_flag_bits = 0; - bool[] memory proof_flag_bits_bool = new bool[](proof_source_flags.length); - for (uint256 i = 0; i < proof_source_flags.length; i++) { - if (proof_source_flags[i] == SOURCE_FROM_HASHES) { - proof_flag_bits_bool[i] = true; - proof_flag_bits = proof_flag_bits | (1 << i); - } else { - proof_flag_bits_bool[i] = false; - proof_flag_bits = proof_flag_bits | (0 << i); - } + uint256 msb = 0; + uint256 _x = x; + + if (x >= 2 ** 128) { + x >>= 128; + msb += 128; + } + if (x >= 2 ** 64) { + x >>= 64; + msb += 64; + } + if (x >= 2 ** 32) { + x >>= 32; + msb += 32; + } + if (x >= 2 ** 16) { + x >>= 16; + msb += 16; + } + if (x >= 2 ** 8) { + x >>= 8; + msb += 8; + } + if (x >= 2 ** 4) { + x >>= 4; + msb += 4; + } + if (x >= 2 ** 2) { + x >>= 2; + msb += 2; + } + if (x >= 2 ** 1) { + msb += 1; } - return (proof_hashes, proof_flag_bits_bool, root); + uint256 lsb = (~_x + 1) & _x; + + if ((lsb == _x) && (msb > 0)) { + return msb; + } else { + return msb + 1; + } } } diff --git a/src/Prover.sol b/src/Prover.sol index 44f720b..ad42592 100644 --- a/src/Prover.sol +++ b/src/Prover.sol @@ -6,6 +6,7 @@ import {MerkleProof} from "@openzeppelin/contracts/utils/cryptography/MerkleProo /** * @notice Wrapper library for proving Merkle MultiProofs. * @author sonicskye. + * @author kamuikatsurgi. */ library Prover { /** @@ -17,11 +18,23 @@ library Prover { * @param leaves Leaf nodes that are being proved to be part of the Merkle tree. * @return A boolean value indicating whether the proof is valid or not. */ - function prove(bytes32[] calldata proof, bool[] calldata flag, bytes32 root, bytes32[] calldata leaves) + function proveMultiProof(bytes32[] calldata proof, bool[] calldata flag, bytes32 root, bytes32[] calldata leaves) public pure returns (bool) { return MerkleProof.multiProofVerifyCalldata(proof, flag, root, leaves); } + + /** + * @notice Verifies the validity of a Merkle SingleProof. + * @dev Uses OpenZeppelin's `verifyCalldata` to validate the proof. + * @param proof The array of sibling hashes that help prove the inclusion of the leaf. + * @param root Root hash of the Merkle tree. + * @param leaf Leaf node that is being proved to be a part of the Merkle tree. + * @return A boolean value indicating whether the proof is valid or not. + */ + function proveSingleProof(bytes32[] calldata proof, bytes32 root, bytes32 leaf) public pure returns (bool) { + return MerkleProof.verifyCalldata(proof, root, leaf); + } } diff --git a/test/MerkleGen.t.sol b/test/MerkleGen.t.sol index fd8cd7d..62d4b81 100644 --- a/test/MerkleGen.t.sol +++ b/test/MerkleGen.t.sol @@ -11,7 +11,7 @@ import {Prover} from "../src/Prover.sol"; */ contract MerkleGenTest is Test { /// @dev A test for standard 4-leaf Merkle tree with all known values. - function test_prove() public { + function test_prove_multi_proof_for_standard_4_leaf_merkle_tree() public { // Generate an array of bytes32 leaves bytes32[] memory leaves = new bytes32[](4); leaves[0] = keccak256(abi.encodePacked(uint256(0))); @@ -32,16 +32,17 @@ contract MerkleGenTest is Test { } // Generate the proof - (bytes32[] memory proof, bool[] memory proofFlagBits, bytes32 root) = MerkleGen.gen(leaves, indices); + (bytes32[] memory proof, bool[] memory proofFlagBits, bytes32 root) = + MerkleGen.generateMultiproof(leaves, indices); emit log_named_bytes32("root", root); // Verify the proof - assertTrue(Prover.prove(proof, proofFlagBits, root, leaf_indexes)); + assertTrue(Prover.proveMultiProof(proof, proofFlagBits, root, leaf_indexes)); } - /// @dev A fuzz test for the Merkle tree. - function testFuzz_prove(uint256 seed, bool[] memory select_leaves_, uint256 numLeaves) public pure { + /// @dev A fuzz test for proving MultiProofs for the Merkle tree. + function testFuzz_prove_multi_proof(uint256 seed, bool[] memory select_leaves_, uint256 numLeaves) public pure { //uint256 numLeaves = 5; // Assume numLeaves = bound(numLeaves, 1, 10000); @@ -78,9 +79,53 @@ contract MerkleGenTest is Test { } // Generate the proof - (bytes32[] memory proof, bool[] memory proofFlagBits, bytes32 root) = MerkleGen.gen(leaves, indices); + (bytes32[] memory proof, bool[] memory proofFlagBits, bytes32 root) = + MerkleGen.generateMultiproof(leaves, indices); // Verify the proof - assertTrue(Prover.prove(proof, proofFlagBits, root, leaf_indexes)); + assertTrue(Prover.proveMultiProof(proof, proofFlagBits, root, leaf_indexes)); + } + + /// @dev A test for standard 4-leaf Merkle tree with all known values. + function test_prove_single_proof_for_standard_4_leaf_merkle_tree() public { + // Generate an array of bytes32 leaves + bytes32[] memory leaves = new bytes32[](4); + leaves[0] = keccak256(abi.encodePacked(uint256(0))); + leaves[1] = keccak256(abi.encodePacked(uint256(1))); + leaves[2] = keccak256(abi.encodePacked(uint256(2))); + leaves[3] = keccak256(abi.encodePacked(uint256(3))); + + // Generate the proof and root + (bytes32[] memory proof, bytes32 root) = MerkleGen.generateSingleProof(leaves, 1); + + emit log_named_bytes32("root", root); + + // Verify the proof + assertTrue(Prover.proveSingleProof(proof, root, leaves[1])); + } + + /// @dev A fuzz test for proving SingleProofs for the Merkle tree. + function testFuzz_prove_single_proof(uint256 seed, uint256 numLeaves, uint256 randomLeafIndex) public { + // Assume total number of leaves + numLeaves = bound(numLeaves, 2, 10000); + + // Seed for generating leaves + seed = bound(seed, 1 ether, 1000 ether); + + randomLeafIndex = bound(randomLeafIndex, 0, numLeaves - 1); + + // Generate an array of bytes32 leaves + bytes32[] memory leaves = new bytes32[](numLeaves); + for (uint256 i = 0; i < numLeaves; i++) { + leaves[i] = keccak256(abi.encodePacked(seed + i)); + } + + // Generate the proof + (bytes32[] memory proof, bytes32 root) = MerkleGen.generateSingleProof(leaves, randomLeafIndex); + + emit log_named_bytes32("root", root); + + // Verify the proof + assertTrue(Prover.proveSingleProof(proof, root, leaves[randomLeafIndex])); } }