Skip to content

Commit

Permalink
add a nonce to the l1 proxy and some unit tests for execution
Browse files Browse the repository at this point in the history
  • Loading branch information
moodysalem committed Jan 12, 2025
1 parent 529b0f8 commit a4bb720
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 31 deletions.
26 changes: 20 additions & 6 deletions l1_proxy/src/StarknetOwnerProxy.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,58 @@ interface IStarknetMessaging {
contract StarknetOwnerProxy {
error InvalidTarget();
error InsufficientBalance();
error InvalidNonce(uint64 current, uint64 nonce);
error CallFailed(bytes data);

IStarknetMessaging public immutable l2MessageBridge;
uint256 public immutable l2Owner;

uint64 public currentNonce;

constructor(IStarknetMessaging _l2MessageBridge, uint256 _l2Owner) {
l2MessageBridge = _l2MessageBridge;
l2Owner = _l2Owner;
}

// Returns the payload split into 31-byte chunks,
// ensuring each element is < 2^251
function getPayload(address target, uint256 value, bytes calldata data) public pure returns (uint256[] memory) {
function getPayload(address target, uint256 value, uint64 nonce, bytes calldata data)
public
pure
returns (uint256[] memory)
{
// Each payload element can hold up to 31 bytes since it has to be expressed as felt252 on Starknet
uint256 chunkCount = (data.length + 30) / 31;
uint256[] memory payload = new uint256[](3 + chunkCount);
uint256[] memory payload = new uint256[](4 + chunkCount);

payload[0] = uint256(uint160(target));
payload[1] = value;
payload[2] = data.length;
payload[2] = nonce;
payload[3] = data.length;

for (uint256 i = 0; i < chunkCount; i++) {
assembly ("memory-safe") {
mstore(add(payload, mul(add(i, 4), 32)), shr(8, calldataload(add(data.offset, mul(i, 31)))))
mstore(add(payload, mul(add(i, 5), 32)), shr(8, calldataload(add(data.offset, mul(i, 31)))))
}
}

return payload;
}

function execute(address target, uint256 value, bytes calldata data) external returns (bytes memory) {
function execute(address target, uint256 value, uint64 nonce, bytes calldata data)
external
returns (bytes memory)
{
if (target == address(0) || target == address(this)) {
revert InvalidTarget();
}
if (address(this).balance < value) revert InsufficientBalance();

if (currentNonce != nonce) revert InvalidNonce(currentNonce, nonce);
currentNonce++;

// Consume message from L2. This will fail if the message has not been sent from L2.
l2MessageBridge.consumeMessageFromL2(l2Owner, getPayload(target, value, data));
l2MessageBridge.consumeMessageFromL2(l2Owner, getPayload(target, value, nonce, data));

(bool success, bytes memory result) = target.call{value: value}(data);
if (!success) revert CallFailed(result);
Expand Down
140 changes: 115 additions & 25 deletions l1_proxy/test/StarknetOwnerProxy.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,95 +4,185 @@ pragma solidity =0.8.28;
import {Test, console} from "forge-std/Test.sol";
import {StarknetOwnerProxy, IStarknetMessaging} from "../src/StarknetOwnerProxy.sol";

contract MockStarknetMessaging is IStarknetMessaging {
mapping(bytes32 => uint256) public messageCount;

function getMessageHash(uint256 fromAddress, uint256[] calldata payload) public pure returns (bytes32) {
return keccak256(abi.encodePacked(fromAddress, payload));
}

function setMessageCount(uint256 fromAddress, uint256[] calldata payload, uint256 count) external {
messageCount[getMessageHash(fromAddress, payload)] = count;
}

function consumeMessageFromL2(uint256 fromAddress, uint256[] calldata payload) external returns (bytes32) {
bytes32 messageHash = getMessageHash(fromAddress, payload);
messageCount[messageHash]--;
return messageHash;
}
}

contract TestTarget {
uint256 public x;

error RandomError(uint256 x);

function setX(uint256 _x) external {
x = _x;
}

function reverts() external view {
revert RandomError(x);
}
}

contract StarknetOwnerProxyTest is Test {
uint256 public l2Owner;
MockStarknetMessaging public messaging;
StarknetOwnerProxy public proxy;
TestTarget public target;

function setUp() public {
proxy = new StarknetOwnerProxy(IStarknetMessaging(address(0x1)), 123);
l2Owner = 0xabcdabcdabcd;
messaging = new MockStarknetMessaging();
proxy = new StarknetOwnerProxy(messaging, l2Owner);
target = new TestTarget();
}

function test_get_payload_empty() public view {
uint256[] memory expected = new uint256[](3);
uint256[] memory expected = new uint256[](4);
expected[0] = 0xdeadbeef;
expected[1] = 123;
expected[2] = 0;
assertEq(proxy.getPayload(address(0xdeadbeef), 123, hex""), expected);
expected[2] = 5;
expected[3] = 0;
assertEq(proxy.getPayload(address(0xdeadbeef), 123, 5, hex""), expected);
}

function test_get_payload_one_partial_word() public view {
uint256[] memory expected = new uint256[](4);
uint256[] memory expected = new uint256[](5);
expected[0] = 0xdeadbeef;
expected[1] = 123;
expected[2] = 3;
expected[3] = 0xabcdef << 224;
assertEq(proxy.getPayload(address(0xdeadbeef), 123, hex"abcdef"), expected);
expected[2] = 12;
expected[3] = 3;
expected[4] = 0xabcdef << 224;
assertEq(proxy.getPayload(address(0xdeadbeef), 123, 12, hex"abcdef"), expected);
}

function test_get_payload_31_bytes() public view {
uint256[] memory expected = new uint256[](4);
uint256[] memory expected = new uint256[](5);
expected[0] = 0xdeadbeef;
expected[1] = 123;
expected[2] = 31;
expected[3] = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff;
expected[2] = 79;
expected[3] = 31;
expected[4] = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff;

assertEq(
proxy.getPayload(
address(0xdeadbeef), 123, hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
address(0xdeadbeef), 123, 79, hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
),
expected
);
}

function test_get_payload_32_bytes() public view {
uint256[] memory expected = new uint256[](5);
uint256[] memory expected = new uint256[](6);
expected[0] = 0xdeadbeef;
expected[1] = 123;
expected[2] = 32;
expected[3] = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff;
expected[4] = 0xff000000000000000000000000000000000000000000000000000000000000;
expected[2] = 555;
expected[3] = 32;
expected[4] = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff;
expected[5] = 0xff000000000000000000000000000000000000000000000000000000000000;

assertEq(
proxy.getPayload(
address(0xdeadbeef), 123, hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
address(0xdeadbeef), 123, 555, hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
),
expected
);
}

function test_get_payload_62_bytes() public view {
uint256[] memory expected = new uint256[](5);
uint256[] memory expected = new uint256[](6);
expected[0] = 0xdeadbeef;
expected[1] = 123;
expected[2] = 62;
expected[3] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd;
expected[2] = 2332;
expected[3] = 62;
expected[4] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd;
expected[5] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd;

assertEq(
proxy.getPayload(
address(0xdeadbeef),
123,
2332,
hex"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd"
),
expected
);
}

function test_get_payload_64_bytes() public view {
uint256[] memory expected = new uint256[](6);
uint256[] memory expected = new uint256[](7);
expected[0] = 0xdeadbeef;
expected[1] = 123;
expected[2] = 64;
expected[3] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd;
expected[4] = 0xef0123456789abcdef0123456789abcdef0123456789abcdef0123456789ab;
expected[5] = 0xcdef << 232;
expected[2] = 9009;
expected[3] = 64;
expected[4] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd;
expected[5] = 0xef0123456789abcdef0123456789abcdef0123456789abcdef0123456789ab;
expected[6] = 0xcdef << 232;

assertEq(
proxy.getPayload(
address(0xdeadbeef),
123,
9009,
hex"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
),
expected
);
}

function test_execute() external {
uint256[] memory payload =
proxy.getPayload(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123)));

messaging.setMessageCount(l2Owner, payload, 1);
assertEq(target.x(), 0);
proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123)));
assertEq(target.x(), 123);
}

function test_execute_twice_fails_nonce() external {
uint256[] memory payload =
proxy.getPayload(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123)));

// theoretically it could consume the message twice, but it doesn't due to the nonce check
messaging.setMessageCount(l2Owner, payload, 2);
assertEq(target.x(), 0);
proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123)));
vm.expectRevert(
abi.encodeWithSelector(StarknetOwnerProxy.InvalidNonce.selector, uint64(1), uint64(0)), address(proxy)
);
proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123)));
}

function test_execute_fails_no_message() external {
vm.expectRevert(address(messaging));
proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123)));
}

function test_execute_call_fails() external {
uint256[] memory payload =
proxy.getPayload(address(target), 0, 0, abi.encodeWithSelector(TestTarget.reverts.selector));

messaging.setMessageCount(l2Owner, payload, 1);

vm.expectRevert(
abi.encodeWithSelector(
StarknetOwnerProxy.CallFailed.selector,
abi.encodeWithSelector(TestTarget.RandomError.selector, uint256(0))
)
);
proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.reverts.selector));
}
}

0 comments on commit a4bb720

Please sign in to comment.