Skip to content

Commit

Permalink
feat: suport for single proofs (#7)
Browse files Browse the repository at this point in the history
* feat: single proofs

* minor nit

* addressing comments

* prefix internal function with _
kamuikatsurgi authored Oct 3, 2024
1 parent 234d05d commit 1cf989b
Showing 3 changed files with 314 additions and 82 deletions.
322 changes: 248 additions & 74 deletions src/MerkleGen.sol
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ import {ArrayLib} from "./libraries/ArrayLib.sol";
/**
* @notice Library for generating Merkle MultiProofs.
* @author sonicskye.
* @author kamuikatsurgi.
*/
library MerkleGen {
using ArrayLib for *;
@@ -14,12 +15,109 @@ library MerkleGen {
bool private constant SOURCE_FROM_PROOF = false;

/**
* @notice Hashes two internal nodes to generate their parent node.
* @notice Generates a Merkle MultiProof for the selected leaves.
* @dev Constructs the necessary proof components and verifies the Merkle root.
* @dev The computed root must match the actual root of the Merkle tree.
* @param hashed_leaves The array of hashed leaves in the Merkle tree.
* @param selected_indexes The indices of the leaves to include in the proof.
* @return Sibling hashes required for the proof.
* @return Flags indicating the source of each proof hash.
* @return Merkle root of the tree.
*/
function generateMultiproof(bytes32[] memory hashed_leaves, uint256[] memory selected_indexes)
public
pure
returns (bytes32[] memory, bool[] memory, bytes32)
{
bytes32[] memory layer = hashed_leaves.copy();
// Append with the same leaf if odd number of leaves
if (layer.length % 2 == 1) {
layer = layer.append(layer[layer.length - 1]);
}
// Create a two dimensional array
bytes32[][] memory layers = new bytes32[][](1);
layers[0] = layer;
bytes32[] memory parent_layer;
while (layer.length > 1) {
parent_layer = _computeParentLayer(layer);
layers = layers.append(parent_layer);
layer = parent_layer;
}

bytes32[] memory proof_hashes;
bool[] memory proof_source_flags;
uint256[] memory indices = selected_indexes.copy();

bytes32[] memory subproof;
bool[] memory source_flags;
for (uint256 i = 0; i < layers.length - 1; i++) {
// Exclude the last layer because it's the root
layer = layers[i];
(indices, subproof, source_flags) = _proveSingleLayer(layer, indices);
proof_hashes = proof_hashes.extend(subproof);
proof_source_flags = proof_source_flags.extend(source_flags);
}

// Get leaves in hashed_leaves that are in selected_indexes
bytes32[] memory indexed_leaves = new bytes32[](selected_indexes.length);
for (uint256 i = 0; i < selected_indexes.length; i++) {
indexed_leaves[i] = hashed_leaves[selected_indexes[i]];
}

bytes32 root = _verifyComputeRoot(indexed_leaves, proof_hashes, proof_source_flags);

// Check if computed root is the same as the root of the tree
require(root == layers[layers.length - 1][0], "Invalid root");

// Convert proof_source_flags to bits and uint256
uint256 proof_flag_bits = 0;
bool[] memory proof_flag_bits_bool = new bool[](proof_source_flags.length);
for (uint256 i = 0; i < proof_source_flags.length; i++) {
if (proof_source_flags[i] == SOURCE_FROM_HASHES) {
proof_flag_bits_bool[i] = true;
proof_flag_bits = proof_flag_bits | (1 << i);
} else {
proof_flag_bits_bool[i] = false;
proof_flag_bits = proof_flag_bits | (0 << i);
}
}

return (proof_hashes, proof_flag_bits_bool, root);
}

/**
* @notice Generates a Merkle proof for a single leaf in the Merkle tree.
* @dev The function computes the proof and the root of the Merkle tree.
* @param leaves The array of leaves used to build the Merkle tree.
* @param leafIndex The index of the leaf for which the proof is generated.
* @return proof An array of sibling hashes forming the Merkle proof for the leaf.
* @return root The root hash of the Merkle tree.
*/
function generateSingleProof(bytes32[] memory leaves, uint256 leafIndex)
public
pure
returns (bytes32[] memory, bytes32)
{
require(leaves.length > 1, "MerkleGen: Leaves should be greater than 1.");

// Append with the same leaf if odd number of leaves
if (leaves.length % 2 == 1) {
leaves = leaves.append(leaves[leaves.length - 1]);
}

bytes32[] memory proof = _getProof(leaves, leafIndex);
bytes32 root = _getRoot(leaves);

return (proof, root);
}

/**
* @notice Hashes two leaf nodes to generate their parent node.
* @param a First child node.
* @param b Second child node.
* @return h Hashed parent node.
*/
function hash_internal_nodes(bytes32 a, bytes32 b) internal pure returns (bytes32 h) {
function _hashLeafPairs(bytes32 a, bytes32 b) internal pure returns (bytes32 h) {
if (a < b) {
h = keccak256(abi.encodePacked(a, b));
} else {
@@ -28,12 +126,12 @@ library MerkleGen {
}

/**
* @notice Computes the next layer in the Merkle tree from the current layer.
* @notice Computes the parent layer in the Merkle tree from the current layer.
* @dev If the current layer has an odd number of nodes, the last node is duplicated.
* @param layer Current layer of the Merkle tree.
* @return Computed next layer.
* @return Computed parent layer.
*/
function compute_next_layer(bytes32[] memory layer) internal pure returns (bytes32[] memory) {
function _computeParentLayer(bytes32[] memory layer) internal pure returns (bytes32[] memory) {
if (layer.length == 1) {
return layer;
}
@@ -43,30 +141,30 @@ library MerkleGen {
layer = layer.append(layer[layer.length - 1]);
}

bytes32[] memory next_layer;
bytes32[] memory parent_layer;

for (uint256 i = 0; i < layer.length; i += 2) {
next_layer = next_layer.append(hash_internal_nodes(layer[i], layer[i + 1]));
parent_layer = parent_layer.append(_hashLeafPairs(layer[i], layer[i + 1]));
}

return next_layer;
return parent_layer;
}

/**
* @notice Calculates the parent index for a given node index.
* @param index Current node index.
* @return Parent node index.
*/
function parent_index(uint256 index) internal pure returns (uint256) {
function _getParentIndex(uint256 index) internal pure returns (uint256) {
return index / 2;
}

/**
* @notice Determines the sibling index of a given node index.
* @notice Calculates the sibling index of a given node index.
* @param index Current node index.
* @return Sibling node index.
*/
function sibling_index(uint256 index) internal pure returns (uint256) {
function _getSiblingIndex(uint256 index) internal pure returns (uint256) {
return index ^ 1;
}

@@ -79,7 +177,7 @@ library MerkleGen {
* @return Sibling hashes required for the proof.
* @return Flags indicating the source of each proof hash.
*/
function prove_single_layer(bytes32[] memory layer, uint256[] memory indices)
function _proveSingleLayer(bytes32[] memory layer, uint256[] memory indices)
internal
pure
returns (uint256[] memory, bytes32[] memory, bool[] memory)
@@ -91,13 +189,13 @@ library MerkleGen {

while (j < indices.length) {
uint256 x = indices[j];
next_indices = next_indices.append(parent_index(x));
next_indices = next_indices.append(_getParentIndex(x));

if (((j + 1) < indices.length) && (indices[j + 1] == sibling_index(x))) {
if (((j + 1) < indices.length) && (indices[j + 1] == _getSiblingIndex(x))) {
j += 1;
source_flags = source_flags.append(SOURCE_FROM_HASHES);
} else {
auth_indices = auth_indices.append(sibling_index(x));
auth_indices = auth_indices.append(_getSiblingIndex(x));
source_flags = source_flags.append(SOURCE_FROM_PROOF);
}
j += 1;
@@ -122,7 +220,7 @@ library MerkleGen {
* @param flag Flag to count.
* @return Number of times the flag appears in the array.
*/
function helper_count(bool[] memory flags, bool flag) internal pure returns (uint256) {
function _helperCount(bool[] memory flags, bool flag) internal pure returns (uint256) {
uint256 count = 0;
for (uint256 i = 0; i < flags.length; i++) {
if (flags[i] == flag) {
@@ -142,15 +240,15 @@ library MerkleGen {
* @param proof_source_flags Flags indicating the source of each proof hash.
* @return Computed Merkle root.
*/
function verify_compute_root(
function _verifyComputeRoot(
bytes32[] memory leaves,
bytes32[] memory proof_hashes,
bool[] memory proof_source_flags
) internal pure returns (bytes32) {
uint256 total_hashes = leaves.length + proof_hashes.length - 1;
require(total_hashes == proof_source_flags.length, "MerkleGen: Invalid total hashes.");
require(
helper_count(proof_source_flags, SOURCE_FROM_PROOF) == proof_hashes.length,
_helperCount(proof_source_flags, SOURCE_FROM_PROOF) == proof_hashes.length,
"MerkleGen: Invalid number of proof hashes."
);

@@ -192,7 +290,7 @@ library MerkleGen {
}

// Compute hash
hashes[i] = hash_internal_nodes(a, b);
hashes[i] = _hashLeafPairs(a, b);
}

if (total_hashes > 0) {
@@ -203,73 +301,149 @@ library MerkleGen {
}

/**
* @notice Generates a Merkle MultiProof for the selected leaves.
* @dev Constructs the necessary proof components and verifies the Merkle root.
* @dev The computed root must match the actual root of the Merkle tree.
* @param hashed_leaves The array of hashed leaves in the Merkle tree.
* @param selected_indexes The indices of the leaves to include in the proof.
* @return Sibling hashes required for the proof.
* @return Flags indicating the source of each proof hash.
* @return Merkle root of the tree.
* @notice Initializes the Merkle tree by placing leaves in the correct positions.
* @dev The tree is represented as a flat array, where the leaves occupy the last `leaves.length` positions.
* @param leaves The array of leaves to be used for the Merkle tree.
* @return A flat array representing the initialized tree, with leaves placed in the correct positions.
*/
function gen(bytes32[] memory hashed_leaves, uint256[] memory selected_indexes)
public
pure
returns (bytes32[] memory, bool[] memory, bytes32)
{
bytes32[] memory layer = hashed_leaves.copy();
// Append with the same leaf if odd number of leaves
if (layer.length % 2 == 1) {
layer = layer.append(layer[layer.length - 1]);
function _initTree(bytes32[] memory leaves) internal pure returns (bytes32[] memory) {
require(leaves.length > 1, "MerkleGen: Leaves should be greater than 1.");

bytes32[] memory tree = new bytes32[](2 * leaves.length - 1);

uint256 index = tree.length - leaves.length;

for (uint256 i = 0; i < leaves.length; i++) {
tree[index + i] = leaves[i];
}
// Create a two dimensional array
bytes32[][] memory layers = new bytes32[][](1);
layers[0] = layer;
bytes32[] memory next_layer;
while (layer.length > 1) {
next_layer = compute_next_layer(layer);
layers = layers.append(next_layer);
layer = next_layer;

return tree;
}

/**
* @notice Builds the complete Merkle tree from the given leaves.
* @dev The function computes the parent nodes from the leaves up to the root of the tree.
* @param leaves The array of leaves to build the Merkle tree.
* @return A flat array representing the complete Merkle tree.
*/
function _buildTree(bytes32[] memory leaves) internal pure returns (bytes32[] memory) {
bytes32[] memory tree = _initTree(leaves);

for (uint256 i = tree.length - 1; i > 1; i -= 2) {
bytes32 left = tree[i - 1];
bytes32 right = tree[i];
bytes32 parent = _hashLeafPairs(left, right);
uint256 parentIndex = (i - 1) / 2;
tree[parentIndex] = parent;
}

bytes32[] memory proof_hashes;
bool[] memory proof_source_flags;
uint256[] memory indices = selected_indexes.copy();
return tree;
}

bytes32[] memory subproof;
bool[] memory source_flags;
for (uint256 i = 0; i < layers.length - 1; i++) {
// Exclude the last layer because it's the root
layer = layers[i];
(indices, subproof, source_flags) = prove_single_layer(layer, indices);
proof_hashes = proof_hashes.extend(subproof);
proof_source_flags = proof_source_flags.extend(source_flags);
/**
* @notice Returns the root hash of the Merkle tree constructed from the given leaves.
* @dev The tree is built and the root (the first element of the tree) is returned.
* @param leaves The array of leaves to build the Merkle tree.
* @return The root hash of the Merkle tree.
*/
function _getRoot(bytes32[] memory leaves) internal pure returns (bytes32) {
require(leaves.length > 1, "MerkleGen: Data should be greater than 1.");

bytes32[] memory tree = _buildTree(leaves);

return tree[0];
}

/**
* @notice Generates the Merkle proof for a specific leaf index.
* @dev Traverses the tree from the leaf at the specified index to the root, collecting the sibling hashes required for proof.
* @param leaves The array of leaves for the Merkle tree.
* @param index The index of the leaf for which the proof is generated.
* @return An array of sibling hashes forming the Merkle proof for the leaf at the specified index.
*/
function _getProof(bytes32[] memory leaves, uint256 index) internal pure returns (bytes32[] memory) {
require(leaves.length > 1, "MerkleGen: Leaves should be greater than 1.");

bytes32[] memory tree = _buildTree(leaves);

uint256 proofLength = _log2CeilBitMagic(leaves.length);
bytes32[] memory proof = new bytes32[](proofLength);

uint256 proofIndex = 0;

uint256 currentIndex = leaves.length - 1 + index;

while (currentIndex > 0) {
uint256 siblingIndex = (currentIndex % 2 == 0) ? currentIndex - 1 : currentIndex + 1;

if (siblingIndex < tree.length) {
proof[proofIndex] = tree[siblingIndex];
proofIndex++;
}

currentIndex = (currentIndex - 1) / 2;
}

// Get leaves in hashed_leaves that are in selected_indexes
bytes32[] memory indexed_leaves = new bytes32[](selected_indexes.length);
for (uint256 i = 0; i < selected_indexes.length; i++) {
indexed_leaves[i] = hashed_leaves[selected_indexes[i]];
bytes32[] memory finalProof = new bytes32[](proofIndex);
for (uint256 i = 0; i < proofIndex; i++) {
finalProof[i] = proof[i];
}

bytes32 root = verify_compute_root(indexed_leaves, proof_hashes, proof_source_flags);
return finalProof;
}

// Check if computed root is the same as the root of the tree
require(root == layers[layers.length - 1][0], "Invalid root");
/**
* @notice Computes the ceiling of the base-2 logarithm of a number using bitwise operations.
* @dev This is an optimized method to compute the log2 value, rounded up to the nearest integer.
* @param x The number for which the log2 ceiling is computed.
* @return The smallest integer greater than or equal to log2(x).
*/
function _log2CeilBitMagic(uint256 x) internal pure returns (uint256) {
if (x <= 1) {
return 0;
}

// Convert proof_source_flags to bits and uint256
uint256 proof_flag_bits = 0;
bool[] memory proof_flag_bits_bool = new bool[](proof_source_flags.length);
for (uint256 i = 0; i < proof_source_flags.length; i++) {
if (proof_source_flags[i] == SOURCE_FROM_HASHES) {
proof_flag_bits_bool[i] = true;
proof_flag_bits = proof_flag_bits | (1 << i);
} else {
proof_flag_bits_bool[i] = false;
proof_flag_bits = proof_flag_bits | (0 << i);
}
uint256 msb = 0;
uint256 _x = x;

if (x >= 2 ** 128) {
x >>= 128;
msb += 128;
}
if (x >= 2 ** 64) {
x >>= 64;
msb += 64;
}
if (x >= 2 ** 32) {
x >>= 32;
msb += 32;
}
if (x >= 2 ** 16) {
x >>= 16;
msb += 16;
}
if (x >= 2 ** 8) {
x >>= 8;
msb += 8;
}
if (x >= 2 ** 4) {
x >>= 4;
msb += 4;
}
if (x >= 2 ** 2) {
x >>= 2;
msb += 2;
}
if (x >= 2 ** 1) {
msb += 1;
}

return (proof_hashes, proof_flag_bits_bool, root);
uint256 lsb = (~_x + 1) & _x;

if ((lsb == _x) && (msb > 0)) {
return msb;
} else {
return msb + 1;
}
}
}
15 changes: 14 additions & 1 deletion src/Prover.sol
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ import {MerkleProof} from "@openzeppelin/contracts/utils/cryptography/MerkleProo
/**
* @notice Wrapper library for proving Merkle MultiProofs.
* @author sonicskye.
* @author kamuikatsurgi.
*/
library Prover {
/**
@@ -17,11 +18,23 @@ library Prover {
* @param leaves Leaf nodes that are being proved to be part of the Merkle tree.
* @return A boolean value indicating whether the proof is valid or not.
*/
function prove(bytes32[] calldata proof, bool[] calldata flag, bytes32 root, bytes32[] calldata leaves)
function proveMultiProof(bytes32[] calldata proof, bool[] calldata flag, bytes32 root, bytes32[] calldata leaves)
public
pure
returns (bool)
{
return MerkleProof.multiProofVerifyCalldata(proof, flag, root, leaves);
}

/**
* @notice Verifies the validity of a Merkle SingleProof.
* @dev Uses OpenZeppelin's `verifyCalldata` to validate the proof.
* @param proof The array of sibling hashes that help prove the inclusion of the leaf.
* @param root Root hash of the Merkle tree.
* @param leaf Leaf node that is being proved to be a part of the Merkle tree.
* @return A boolean value indicating whether the proof is valid or not.
*/
function proveSingleProof(bytes32[] calldata proof, bytes32 root, bytes32 leaf) public pure returns (bool) {
return MerkleProof.verifyCalldata(proof, root, leaf);
}
}
59 changes: 52 additions & 7 deletions test/MerkleGen.t.sol
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ import {Prover} from "../src/Prover.sol";
*/
contract MerkleGenTest is Test {
/// @dev A test for standard 4-leaf Merkle tree with all known values.
function test_prove() public {
function test_prove_multi_proof_for_standard_4_leaf_merkle_tree() public {
// Generate an array of bytes32 leaves
bytes32[] memory leaves = new bytes32[](4);
leaves[0] = keccak256(abi.encodePacked(uint256(0)));
@@ -32,16 +32,17 @@ contract MerkleGenTest is Test {
}

// Generate the proof
(bytes32[] memory proof, bool[] memory proofFlagBits, bytes32 root) = MerkleGen.gen(leaves, indices);
(bytes32[] memory proof, bool[] memory proofFlagBits, bytes32 root) =
MerkleGen.generateMultiproof(leaves, indices);

emit log_named_bytes32("root", root);

// Verify the proof
assertTrue(Prover.prove(proof, proofFlagBits, root, leaf_indexes));
assertTrue(Prover.proveMultiProof(proof, proofFlagBits, root, leaf_indexes));
}

/// @dev A fuzz test for the Merkle tree.
function testFuzz_prove(uint256 seed, bool[] memory select_leaves_, uint256 numLeaves) public pure {
/// @dev A fuzz test for proving MultiProofs for the Merkle tree.
function testFuzz_prove_multi_proof(uint256 seed, bool[] memory select_leaves_, uint256 numLeaves) public pure {
//uint256 numLeaves = 5;
// Assume
numLeaves = bound(numLeaves, 1, 10000);
@@ -78,9 +79,53 @@ contract MerkleGenTest is Test {
}

// Generate the proof
(bytes32[] memory proof, bool[] memory proofFlagBits, bytes32 root) = MerkleGen.gen(leaves, indices);
(bytes32[] memory proof, bool[] memory proofFlagBits, bytes32 root) =
MerkleGen.generateMultiproof(leaves, indices);

// Verify the proof
assertTrue(Prover.prove(proof, proofFlagBits, root, leaf_indexes));
assertTrue(Prover.proveMultiProof(proof, proofFlagBits, root, leaf_indexes));
}

/// @dev A test for standard 4-leaf Merkle tree with all known values.
function test_prove_single_proof_for_standard_4_leaf_merkle_tree() public {
// Generate an array of bytes32 leaves
bytes32[] memory leaves = new bytes32[](4);
leaves[0] = keccak256(abi.encodePacked(uint256(0)));
leaves[1] = keccak256(abi.encodePacked(uint256(1)));
leaves[2] = keccak256(abi.encodePacked(uint256(2)));
leaves[3] = keccak256(abi.encodePacked(uint256(3)));

// Generate the proof and root
(bytes32[] memory proof, bytes32 root) = MerkleGen.generateSingleProof(leaves, 1);

emit log_named_bytes32("root", root);

// Verify the proof
assertTrue(Prover.proveSingleProof(proof, root, leaves[1]));
}

/// @dev A fuzz test for proving SingleProofs for the Merkle tree.
function testFuzz_prove_single_proof(uint256 seed, uint256 numLeaves, uint256 randomLeafIndex) public {
// Assume total number of leaves
numLeaves = bound(numLeaves, 2, 10000);

// Seed for generating leaves
seed = bound(seed, 1 ether, 1000 ether);

randomLeafIndex = bound(randomLeafIndex, 0, numLeaves - 1);

// Generate an array of bytes32 leaves
bytes32[] memory leaves = new bytes32[](numLeaves);
for (uint256 i = 0; i < numLeaves; i++) {
leaves[i] = keccak256(abi.encodePacked(seed + i));
}

// Generate the proof
(bytes32[] memory proof, bytes32 root) = MerkleGen.generateSingleProof(leaves, randomLeafIndex);

emit log_named_bytes32("root", root);

// Verify the proof
assertTrue(Prover.proveSingleProof(proof, root, leaves[randomLeafIndex]));
}
}

0 comments on commit 1cf989b

Please sign in to comment.