diff --git a/l1_proxy/src/StarknetOwnerProxy.sol b/l1_proxy/src/StarknetOwnerProxy.sol index 0613742..d98cb0b 100644 --- a/l1_proxy/src/StarknetOwnerProxy.sol +++ b/l1_proxy/src/StarknetOwnerProxy.sol @@ -8,11 +8,14 @@ interface IStarknetMessaging { contract StarknetOwnerProxy { error InvalidTarget(); error InsufficientBalance(); + error InvalidNonce(uint64 current, uint64 nonce); error CallFailed(bytes data); IStarknetMessaging public immutable l2MessageBridge; uint256 public immutable l2Owner; + uint64 public currentNonce; + constructor(IStarknetMessaging _l2MessageBridge, uint256 _l2Owner) { l2MessageBridge = _l2MessageBridge; l2Owner = _l2Owner; @@ -20,32 +23,43 @@ contract StarknetOwnerProxy { // Returns the payload split into 31-byte chunks, // ensuring each element is < 2^251 - function getPayload(address target, uint256 value, bytes calldata data) public pure returns (uint256[] memory) { + function getPayload(address target, uint256 value, uint64 nonce, bytes calldata data) + public + pure + returns (uint256[] memory) + { // Each payload element can hold up to 31 bytes since it has to be expressed as felt252 on Starknet uint256 chunkCount = (data.length + 30) / 31; - uint256[] memory payload = new uint256[](3 + chunkCount); + uint256[] memory payload = new uint256[](4 + chunkCount); payload[0] = uint256(uint160(target)); payload[1] = value; - payload[2] = data.length; + payload[2] = nonce; + payload[3] = data.length; for (uint256 i = 0; i < chunkCount; i++) { assembly ("memory-safe") { - mstore(add(payload, mul(add(i, 4), 32)), shr(8, calldataload(add(data.offset, mul(i, 31))))) + mstore(add(payload, mul(add(i, 5), 32)), shr(8, calldataload(add(data.offset, mul(i, 31))))) } } return payload; } - function execute(address target, uint256 value, bytes calldata data) external returns (bytes memory) { + function execute(address target, uint256 value, uint64 nonce, bytes calldata data) + external + returns (bytes memory) + { if (target == address(0) || target == address(this)) { revert InvalidTarget(); } if (address(this).balance < value) revert InsufficientBalance(); + if (currentNonce != nonce) revert InvalidNonce(currentNonce, nonce); + currentNonce++; + // Consume message from L2. This will fail if the message has not been sent from L2. - l2MessageBridge.consumeMessageFromL2(l2Owner, getPayload(target, value, data)); + l2MessageBridge.consumeMessageFromL2(l2Owner, getPayload(target, value, nonce, data)); (bool success, bytes memory result) = target.call{value: value}(data); if (!success) revert CallFailed(result); diff --git a/l1_proxy/test/StarknetOwnerProxy.t.sol b/l1_proxy/test/StarknetOwnerProxy.t.sol index 7c95ab7..06d8d60 100644 --- a/l1_proxy/test/StarknetOwnerProxy.t.sol +++ b/l1_proxy/test/StarknetOwnerProxy.t.sol @@ -4,73 +4,117 @@ pragma solidity =0.8.28; import {Test, console} from "forge-std/Test.sol"; import {StarknetOwnerProxy, IStarknetMessaging} from "../src/StarknetOwnerProxy.sol"; +contract MockStarknetMessaging is IStarknetMessaging { + mapping(bytes32 => uint256) public messageCount; + + function getMessageHash(uint256 fromAddress, uint256[] calldata payload) public pure returns (bytes32) { + return keccak256(abi.encodePacked(fromAddress, payload)); + } + + function setMessageCount(uint256 fromAddress, uint256[] calldata payload, uint256 count) external { + messageCount[getMessageHash(fromAddress, payload)] = count; + } + + function consumeMessageFromL2(uint256 fromAddress, uint256[] calldata payload) external returns (bytes32) { + bytes32 messageHash = getMessageHash(fromAddress, payload); + messageCount[messageHash]--; + return messageHash; + } +} + +contract TestTarget { + uint256 public x; + + error RandomError(uint256 x); + + function setX(uint256 _x) external { + x = _x; + } + + function reverts() external view { + revert RandomError(x); + } +} + contract StarknetOwnerProxyTest is Test { + uint256 public l2Owner; + MockStarknetMessaging public messaging; StarknetOwnerProxy public proxy; + TestTarget public target; function setUp() public { - proxy = new StarknetOwnerProxy(IStarknetMessaging(address(0x1)), 123); + l2Owner = 0xabcdabcdabcd; + messaging = new MockStarknetMessaging(); + proxy = new StarknetOwnerProxy(messaging, l2Owner); + target = new TestTarget(); } function test_get_payload_empty() public view { - uint256[] memory expected = new uint256[](3); + uint256[] memory expected = new uint256[](4); expected[0] = 0xdeadbeef; expected[1] = 123; - expected[2] = 0; - assertEq(proxy.getPayload(address(0xdeadbeef), 123, hex""), expected); + expected[2] = 5; + expected[3] = 0; + assertEq(proxy.getPayload(address(0xdeadbeef), 123, 5, hex""), expected); } function test_get_payload_one_partial_word() public view { - uint256[] memory expected = new uint256[](4); + uint256[] memory expected = new uint256[](5); expected[0] = 0xdeadbeef; expected[1] = 123; - expected[2] = 3; - expected[3] = 0xabcdef << 224; - assertEq(proxy.getPayload(address(0xdeadbeef), 123, hex"abcdef"), expected); + expected[2] = 12; + expected[3] = 3; + expected[4] = 0xabcdef << 224; + assertEq(proxy.getPayload(address(0xdeadbeef), 123, 12, hex"abcdef"), expected); } function test_get_payload_31_bytes() public view { - uint256[] memory expected = new uint256[](4); + uint256[] memory expected = new uint256[](5); expected[0] = 0xdeadbeef; expected[1] = 123; - expected[2] = 31; - expected[3] = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff; + expected[2] = 79; + expected[3] = 31; + expected[4] = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff; assertEq( proxy.getPayload( - address(0xdeadbeef), 123, hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + address(0xdeadbeef), 123, 79, hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" ), expected ); } function test_get_payload_32_bytes() public view { - uint256[] memory expected = new uint256[](5); + uint256[] memory expected = new uint256[](6); expected[0] = 0xdeadbeef; expected[1] = 123; - expected[2] = 32; - expected[3] = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff; - expected[4] = 0xff000000000000000000000000000000000000000000000000000000000000; + expected[2] = 555; + expected[3] = 32; + expected[4] = 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff; + expected[5] = 0xff000000000000000000000000000000000000000000000000000000000000; assertEq( proxy.getPayload( - address(0xdeadbeef), 123, hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" + address(0xdeadbeef), 123, 555, hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff" ), expected ); } function test_get_payload_62_bytes() public view { - uint256[] memory expected = new uint256[](5); + uint256[] memory expected = new uint256[](6); expected[0] = 0xdeadbeef; expected[1] = 123; - expected[2] = 62; - expected[3] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd; + expected[2] = 2332; + expected[3] = 62; expected[4] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd; + expected[5] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd; assertEq( proxy.getPayload( address(0xdeadbeef), 123, + 2332, hex"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd" ), expected @@ -78,21 +122,67 @@ contract StarknetOwnerProxyTest is Test { } function test_get_payload_64_bytes() public view { - uint256[] memory expected = new uint256[](6); + uint256[] memory expected = new uint256[](7); expected[0] = 0xdeadbeef; expected[1] = 123; - expected[2] = 64; - expected[3] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd; - expected[4] = 0xef0123456789abcdef0123456789abcdef0123456789abcdef0123456789ab; - expected[5] = 0xcdef << 232; + expected[2] = 9009; + expected[3] = 64; + expected[4] = 0x0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcd; + expected[5] = 0xef0123456789abcdef0123456789abcdef0123456789abcdef0123456789ab; + expected[6] = 0xcdef << 232; assertEq( proxy.getPayload( address(0xdeadbeef), 123, + 9009, hex"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" ), expected ); } + + function test_execute() external { + uint256[] memory payload = + proxy.getPayload(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123))); + + messaging.setMessageCount(l2Owner, payload, 1); + assertEq(target.x(), 0); + proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123))); + assertEq(target.x(), 123); + } + + function test_execute_twice_fails_nonce() external { + uint256[] memory payload = + proxy.getPayload(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123))); + + // theoretically it could consume the message twice, but it doesn't due to the nonce check + messaging.setMessageCount(l2Owner, payload, 2); + assertEq(target.x(), 0); + proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123))); + vm.expectRevert( + abi.encodeWithSelector(StarknetOwnerProxy.InvalidNonce.selector, uint64(1), uint64(0)), address(proxy) + ); + proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123))); + } + + function test_execute_fails_no_message() external { + vm.expectRevert(address(messaging)); + proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.setX.selector, (123))); + } + + function test_execute_call_fails() external { + uint256[] memory payload = + proxy.getPayload(address(target), 0, 0, abi.encodeWithSelector(TestTarget.reverts.selector)); + + messaging.setMessageCount(l2Owner, payload, 1); + + vm.expectRevert( + abi.encodeWithSelector( + StarknetOwnerProxy.CallFailed.selector, + abi.encodeWithSelector(TestTarget.RandomError.selector, uint256(0)) + ) + ); + proxy.execute(address(target), 0, 0, abi.encodeWithSelector(TestTarget.reverts.selector)); + } }