diff --git a/.husky/pre-commit b/.husky/pre-commit old mode 100644 new mode 100755 diff --git a/src/CommunityStakingModule.sol b/src/CommunityStakingModule.sol index 80fb8c6a..0b81d96a 100644 --- a/src/CommunityStakingModule.sol +++ b/src/CommunityStakingModule.sol @@ -3,10 +3,16 @@ pragma solidity 0.8.21; +import { SafeCast } from "@openzeppelin/contracts/utils/math/SafeCast.sol"; +import { Math } from "@openzeppelin/contracts/utils/math/Math.sol"; + +import { ICommunityStakingBondManager } from "./interfaces/ICommunityStakingBondManager.sol"; import { IStakingModule } from "./interfaces/IStakingModule.sol"; -import "./interfaces/ICommunityStakingBondManager.sol"; -import "./interfaces/ILidoLocator.sol"; -import "./interfaces/ILido.sol"; +import { ILidoLocator } from "./interfaces/ILidoLocator.sol"; +import { ILido } from "./interfaces/ILido.sol"; + +import { QueueLib } from "./lib/QueueLib.sol"; +import { Batch } from "./lib/Batch.sol"; import "./lib/SigningKeys.sol"; import "./lib/StringToUint256WithZeroMap.sol"; @@ -36,26 +42,48 @@ contract CommunityStakingModuleBase { ); event NodeOperatorNameSet(uint256 indexed nodeOperatorId, string name); - event VettedKeysCountChanged( + event VettedSigningKeysCountChanged( uint256 indexed nodeOperatorId, - uint256 approvedKeysCount + uint256 approvedValidatorsCount ); - event DepositedKeysCountChanged( + event DepositedSigningKeysCountChanged( uint256 indexed nodeOperatorId, - uint256 depositedKeysCount + uint256 depositedValidatorsCount ); - event ExitedKeysCountChanged( + event ExitedSigningKeysCountChanged( uint256 indexed nodeOperatorId, - uint256 exitedKeysCount + uint256 exitedValidatorsCount ); - event TotalKeysCountChanged( + event TotalSigningKeysCountChanged( + uint256 indexed nodeOperatorId, + uint256 totalValidatorsCount + ); + + event BatchEnqueued( uint256 indexed nodeOperatorId, - uint256 totalKeysCount + uint256 startIndex, + uint256 count ); + + event StakingModuleTypeSet(bytes32 moduleType); + event LocatorContractSet(address locatorAddress); + event UnvettingFeeSet(uint256 unvettingFee); } contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { using StringToUint256WithZeroMap for mapping(string => uint256); + using QueueLib for QueueLib.Queue; + + uint256 public constant MAX_NODE_OPERATOR_NAME_LENGTH = 255; + bytes32 public constant SIGNING_KEYS_POSITION = + keccak256("lido.CommunityStakingModule.signingKeysPosition"); + + uint256 public unvettingFee; + QueueLib.Queue public queue; + + ICommunityStakingBondManager public bondManager; + ILidoLocator public lidoLocator; + uint256 private nodeOperatorsCount; uint256 private activeNodeOperatorsCount; bytes32 private moduleType; @@ -63,43 +91,55 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { mapping(uint256 => NodeOperator) private nodeOperators; mapping(string => uint256) private nodeOperatorIdsByName; - bytes32 public constant SIGNING_KEYS_POSITION = - keccak256("lido.CommunityStakingModule.signingKeysPosition"); - uint256 public constant MAX_NODE_OPERATOR_NAME_LENGTH = 255; + uint256 private _totalDepositedValidators; + uint256 private _totalExitedValidators; + uint256 private _totalAddedValidators; - address public bondManagerAddress; - address public lidoLocator; + modifier onlyActiveNodeOperator(uint256 _nodeOperatorId) { + require( + _nodeOperatorId < nodeOperatorsCount, + "node operator does not exist" + ); + require( + nodeOperators[_nodeOperatorId].active, + "node operator is not active" + ); + _; + } + + modifier onlyKeyValidatorOrNodeOperatorManager() { + // TODO: check the role + _; + } + + modifier onlyKeyValidator() { + // TODO: check the role + _; + } constructor(bytes32 _type, address _locator) { moduleType = _type; + emit StakingModuleTypeSet(_type); require(_locator != address(0), "lido locator is zero address"); - lidoLocator = _locator; - } - - function setBondManager(address _bondManagerAddress) external { - // TODO add role check - require( - address(bondManagerAddress) == address(0), - "already initialized" - ); - bondManagerAddress = _bondManagerAddress; + lidoLocator = ILidoLocator(_locator); + emit LocatorContractSet(_locator); } - function _bondManager() - internal - view - returns (ICommunityStakingBondManager) - { - return ICommunityStakingBondManager(bondManagerAddress); + function setBondManager(address _bondManager) external { + // TODO: add role check + require(address(bondManager) == address(0), "already initialized"); + bondManager = ICommunityStakingBondManager(_bondManager); } - function _lidoLocator() internal view returns (ILidoLocator) { - return ILidoLocator(lidoLocator); + function setUnvettingFee(uint256 unvettingFee_) external { + // TODO: add role check + unvettingFee = unvettingFee_; + emit UnvettingFeeSet(unvettingFee_); } function _lido() internal view returns (ILido) { - return ILido(_lidoLocator().lido()); + return ILido(lidoLocator.lido()); } function getType() external view returns (bytes32) { @@ -110,22 +150,15 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { external view returns ( - uint256 totalExitedValidators, - uint256 totalDepositedValidators, - uint256 depositableValidatorsCount + uint256 /* totalExitedValidators */, + uint256 /* totalDepositedValidators */, + uint256 /* depositableValidatorsCount */ ) { - for (uint256 i = 0; i < nodeOperatorsCount; i++) { - totalExitedValidators += nodeOperators[i].totalExitedKeys; - totalDepositedValidators += nodeOperators[i].totalDepositedKeys; - depositableValidatorsCount += - nodeOperators[i].totalAddedKeys - - nodeOperators[i].totalExitedKeys; - } return ( - totalExitedValidators, - totalDepositedValidators, - depositableValidatorsCount + _totalExitedValidators, + _totalDepositedValidators, + _totalAddedValidators - _totalExitedValidators ); } @@ -140,7 +173,7 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { _onlyValidNodeOperatorName(_name); require( - msg.value == _bondManager().getRequiredBondETHForKeys(_keysCount), + msg.value == bondManager.getRequiredBondETHForKeys(_keysCount), "eth value is not equal to required bond" ); @@ -153,7 +186,7 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { nodeOperatorsCount++; activeNodeOperatorsCount++; - _bondManager().depositETH{ value: msg.value }(msg.sender, id); + bondManager.depositETH{ value: msg.value }(msg.sender, id); _addSigningKeys(id, _keysCount, _publicKeys, _signatures); @@ -167,7 +200,7 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { bytes calldata _publicKeys, bytes calldata _signatures ) external { - // TODO sanity checks + // TODO: sanity checks _onlyValidNodeOperatorName(_name); uint256 id = nodeOperatorsCount; @@ -179,10 +212,10 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { nodeOperatorsCount++; activeNodeOperatorsCount++; - _bondManager().depositStETH( + bondManager.depositStETH( msg.sender, id, - _bondManager().getRequiredBondStETHForKeys(_keysCount) + bondManager.getRequiredBondStETHForKeys(_keysCount) ); _addSigningKeys(id, _keysCount, _publicKeys, _signatures); @@ -252,10 +285,10 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { nodeOperatorsCount++; activeNodeOperatorsCount++; - _bondManager().depositStETHWithPermit( + bondManager.depositStETHWithPermit( _from, id, - _bondManager().getRequiredBondStETHForKeys(_keysCount), + bondManager.getRequiredBondStETHForKeys(_keysCount), _permit ); @@ -283,10 +316,10 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { nodeOperatorsCount++; activeNodeOperatorsCount++; - _bondManager().depositWstETH( + bondManager.depositWstETH( msg.sender, id, - _bondManager().getRequiredBondWstETHForKeys(_keysCount) + bondManager.getRequiredBondWstETHForKeys(_keysCount) ); _addSigningKeys(id, _keysCount, _publicKeys, _signatures); @@ -356,10 +389,10 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { nodeOperatorsCount++; activeNodeOperatorsCount++; - _bondManager().depositWstETHWithPermit( + bondManager.depositWstETHWithPermit( _from, id, - _bondManager().getRequiredBondWstETHForKeys(_keysCount), + bondManager.getRequiredBondWstETHForKeys(_keysCount), _permit ); @@ -374,19 +407,15 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { bytes calldata _publicKeys, bytes calldata _signatures ) external payable onlyExistingNodeOperator(_nodeOperatorId) { - // TODO sanity checks - // TODO store keys + // TODO: sanity checks require( msg.value == - _bondManager().getRequiredBondETH(_nodeOperatorId, _keysCount), + bondManager.getRequiredBondETH(_nodeOperatorId, _keysCount), "eth value is not equal to required bond" ); - _bondManager().depositETH{ value: msg.value }( - msg.sender, - _nodeOperatorId - ); + bondManager.depositETH{ value: msg.value }(msg.sender, _nodeOperatorId); _addSigningKeys(_nodeOperatorId, _keysCount, _publicKeys, _signatures); } @@ -397,13 +426,12 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { bytes calldata _publicKeys, bytes calldata _signatures ) external onlyExistingNodeOperator(_nodeOperatorId) { - // TODO sanity checks - // TODO store keys + // TODO: sanity checks - _bondManager().depositStETH( + bondManager.depositStETH( msg.sender, _nodeOperatorId, - _bondManager().getRequiredBondStETH(_nodeOperatorId, _keysCount) + bondManager.getRequiredBondStETH(_nodeOperatorId, _keysCount) ); _addSigningKeys(_nodeOperatorId, _keysCount, _publicKeys, _signatures); @@ -457,10 +485,10 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { // TODO sanity checks // TODO store keys - _bondManager().depositStETHWithPermit( + bondManager.depositStETHWithPermit( _from, _nodeOperatorId, - _bondManager().getRequiredBondStETH(_nodeOperatorId, _keysCount), + bondManager.getRequiredBondStETH(_nodeOperatorId, _keysCount), _permit ); @@ -473,13 +501,12 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { bytes calldata _publicKeys, bytes calldata _signatures ) external onlyExistingNodeOperator(_nodeOperatorId) { - // TODO sanity checks - // TODO store keys + // TODO: sanity checks - _bondManager().depositWstETH( + bondManager.depositWstETH( msg.sender, _nodeOperatorId, - _bondManager().getRequiredBondWstETH(_nodeOperatorId, _keysCount) + bondManager.getRequiredBondWstETH(_nodeOperatorId, _keysCount) ); _addSigningKeys(_nodeOperatorId, _keysCount, _publicKeys, _signatures); @@ -533,10 +560,10 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { // TODO sanity checks // TODO store keys - _bondManager().depositWstETHWithPermit( + bondManager.depositWstETHWithPermit( _from, _nodeOperatorId, - _bondManager().getRequiredBondWstETH(_nodeOperatorId, _keysCount), + bondManager.getRequiredBondWstETH(_nodeOperatorId, _keysCount), _permit ); @@ -587,7 +614,7 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { uint256 totalDepositedValidators ) { - NodeOperator memory no = nodeOperators[_nodeOperatorId]; + NodeOperator storage no = nodeOperators[_nodeOperatorId]; active = no.active; name = _fullInfo ? no.name : ""; rewardAddress = no.rewardAddress; @@ -660,22 +687,22 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { } function onRewardsMinted(uint256 /*_totalShares*/) external { - // TODO implement + // TODO: implement } function updateStuckValidatorsCount( bytes calldata /*_nodeOperatorIds*/, bytes calldata /*_stuckValidatorsCounts*/ ) external { - // TODO implement + // TODO: implement } function updateExitedValidatorsCount( bytes calldata _nodeOperatorIds, bytes calldata _exitedValidatorsCounts ) external { - // TODO implement - // emit ExitedKeysCountChanged( + // TODO: implement + // emit ExitedSigningKeysCountChanged( // _nodeOperatorId, // _exitedValidatorsCount // ); @@ -685,7 +712,7 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { uint256 /*_nodeOperatorId*/, uint256 /*_refundedValidatorsCount*/ ) external { - // TODO implement + // TODO: implement } function updateTargetValidatorsLimits( @@ -693,11 +720,11 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { bool /*_isTargetLimitActive*/, uint256 /*_targetLimit*/ ) external { - // TODO implement + // TODO: implement } function onExitedAndStuckValidatorsCountsUpdated() external { - // TODO implement + // TODO: implement } function unsafeUpdateValidatorsCount( @@ -705,7 +732,60 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { uint256 /*_exitedValidatorsKeysCount*/, uint256 /*_stuckValidatorsKeysCount*/ ) external { - // TODO implement + // TODO: implement + } + + function vetKeys( + uint256 nodeOperatorId, + uint64 vettedKeysCount + ) external onlyKeyValidator { + NodeOperator storage no = nodeOperators[nodeOperatorId]; + + require( + vettedKeysCount > no.totalVettedKeys, + "Wrong vettedKeysCount: less than already vetted" + ); + require( + vettedKeysCount <= no.totalAddedKeys, + "Wrong vettedKeysCount: more than added" + ); + + uint64 count = SafeCast.toUint64(vettedKeysCount - no.totalVettedKeys); + uint64 start = SafeCast.toUint64( + no.totalVettedKeys == 0 ? 0 : no.totalVettedKeys - 1 + ); + + bytes32 pointer = Batch.serialize({ + nodeOperatorId: SafeCast.toUint128(nodeOperatorId), + start: start, + count: count + }); + + no.totalVettedKeys = vettedKeysCount; + queue.enqueue(pointer); + + emit BatchEnqueued(nodeOperatorId, start, count); + emit VettedSigningKeysCountChanged(nodeOperatorId, vettedKeysCount); + + _incrementNonce(); + } + + function unvetKeys( + uint256 nodeOperatorId + ) external onlyKeyValidatorOrNodeOperatorManager { + _unvetKeys(nodeOperatorId); + bondManager.penalize(nodeOperatorId, unvettingFee); + } + + function unsafeUnvetKeys(uint256 nodeOperatorId) external onlyKeyValidator { + _unvetKeys(nodeOperatorId); + } + + function _unvetKeys(uint256 nodeOperatorId) internal { + NodeOperator storage no = nodeOperators[nodeOperatorId]; + no.totalVettedKeys = no.totalDepositedKeys; + emit VettedSigningKeysCountChanged(nodeOperatorId, no.totalVettedKeys); + _incrementNonce(); } function onWithdrawalCredentialsChanged() external { @@ -730,8 +810,9 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { _signatures ); + _totalAddedValidators += _keysCount; nodeOperators[_nodeOperatorId].totalAddedKeys += _keysCount; - emit TotalKeysCountChanged( + emit TotalSigningKeysCountChanged( _nodeOperatorId, nodeOperators[_nodeOperatorId].totalAddedKeys ); @@ -741,44 +822,190 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase { function obtainDepositData( uint256 _depositsCount, - bytes calldata /*_depositCalldata*/ + bytes calldata /* _depositCalldata */ ) external returns (bytes memory publicKeys, bytes memory signatures) { (publicKeys, signatures) = SigningKeys.initKeysSigsBuf(_depositsCount); + uint256 limit = _depositsCount; uint256 loadedKeysCount = 0; - for ( - uint256 nodeOperatorId; - nodeOperatorId < nodeOperatorsCount; - nodeOperatorId++ - ) { - NodeOperator storage no = nodeOperators[nodeOperatorId]; - // TODO replace total added to total vetted later - uint256 availableKeys = no.totalAddedKeys - no.totalDepositedKeys; - if (availableKeys == 0) continue; - - uint256 _startIndex = no.totalDepositedKeys; - uint256 _keysCount = _depositsCount > availableKeys - ? availableKeys - : _depositsCount; + + for (bytes32 p = queue.peek(); !Batch.isNil(p); ) { + ( + uint256 nodeOperatorId, + uint256 startIndex, + uint256 depositableKeysCount + ) = _depositableKeysInBatch(p); + + uint256 keysCount = Math.min(limit, depositableKeysCount); + if (depositableKeysCount == keysCount) { + queue.dequeue(); + } + SigningKeys.loadKeysSigs( SIGNING_KEYS_POSITION, nodeOperatorId, - _startIndex, - _keysCount, + startIndex, + keysCount, publicKeys, signatures, loadedKeysCount ); - loadedKeysCount += _keysCount; - // TODO maybe depositor bot should initiate this increment - no.totalDepositedKeys += _keysCount; - emit DepositedKeysCountChanged( + loadedKeysCount += keysCount; + + _totalDepositedValidators += keysCount; + NodeOperator storage no = nodeOperators[nodeOperatorId]; + no.totalDepositedKeys += keysCount; + require( + no.totalDepositedKeys <= no.totalVettedKeys, + "too many keys" + ); + + emit DepositedSigningKeysCountChanged( nodeOperatorId, no.totalDepositedKeys ); + + limit = limit - keysCount; + if (limit == 0) { + break; + } + + p = queue.peek(); + } + + require(loadedKeysCount == _depositsCount, "NOT_ENOUGH_KEYS"); + _incrementNonce(); + } + + function _depositableKeysInBatch( + bytes32 batch + ) + internal + view + returns ( + uint256 nodeOperatorId, + uint256 startIndex, + uint256 depositableKeysCount + ) + { + uint256 start; + uint256 count; + + (nodeOperatorId, start, count) = Batch.deserialize(batch); + + NodeOperator storage no = nodeOperators[nodeOperatorId]; + _assertIsValidBatch(no, start, count); + + startIndex = Math.max(start, no.totalDepositedKeys); + depositableKeysCount = start + count - startIndex; + } + + function _assertIsValidBatch( + NodeOperator storage no, + uint256 _start, + uint256 _count + ) internal view { + require(_count != 0, "Empty batch given"); + require( + _unvettedKeysInBatch(no, _start, _count) == false, + "Batch contains unvetted keys" + ); + require( + _start + _count <= no.totalAddedKeys, + "Invalid batch range: not enough keys" + ); + require( + _start <= no.totalDepositedKeys, + "Invalid batch range: skipped keys" + ); + } + + /// @dev returns the next pointer to start cleanup from + function cleanDepositQueue( + uint256 maxItems, + bytes32 pointer + ) external returns (bytes32) { + require(maxItems > 0, "Queue walkthrough limit is not set"); + + if (Batch.isNil(pointer)) { + pointer = queue.front; } - if (loadedKeysCount != _depositsCount) { - revert("NOT_ENOUGH_KEYS"); + + for (uint256 i; i < maxItems; i++) { + bytes32 item = queue.at(pointer); + if (Batch.isNil(item)) { + break; + } + + (uint256 nodeOperatorId, uint256 start, uint256 count) = Batch + .deserialize(item); + NodeOperator storage no = nodeOperators[nodeOperatorId]; + if (_unvettedKeysInBatch(no, start, count)) { + queue.remove(pointer, item); + } + + pointer = item; } + + return pointer; + } + + function depositQueue( + uint256 maxItems, + bytes32 pointer + ) + external + view + returns ( + bytes32[] memory items, + bytes32 /* pointer */, + uint256 /* count */ + ) + { + require(maxItems > 0, "Queue walkthrough limit is not set"); + + if (Batch.isNil(pointer)) { + pointer = queue.front; + } + + return queue.list(pointer, maxItems); + } + + /// @dev returns the next pointer to start check from + function isQueueHasUnvettedKeys( + uint256 maxItems, + bytes32 pointer + ) external view returns (bool, bytes32) { + require(maxItems > 0, "Queue walkthrough limit is not set"); + + if (Batch.isNil(pointer)) { + pointer = queue.front; + } + + for (uint256 i; i < maxItems; i++) { + bytes32 item = queue.at(pointer); + if (Batch.isNil(item)) { + break; + } + + (uint256 nodeOperatorId, uint256 start, uint256 count) = Batch + .deserialize(item); + NodeOperator storage no = nodeOperators[nodeOperatorId]; + if (_unvettedKeysInBatch(no, start, count)) { + return (true, pointer); + } + + pointer = item; + } + + return (false, pointer); + } + + function _unvettedKeysInBatch( + NodeOperator storage no, + uint256 _start, + uint256 _count + ) internal view returns (bool) { + return _start + _count > no.totalVettedKeys; } function _incrementNonce() internal { diff --git a/src/interfaces/ICommunityStakingBondManager.sol b/src/interfaces/ICommunityStakingBondManager.sol index 1545488e..77893ad4 100644 --- a/src/interfaces/ICommunityStakingBondManager.sol +++ b/src/interfaces/ICommunityStakingBondManager.sol @@ -101,4 +101,6 @@ interface ICommunityStakingBondManager { uint256 nodeOperatorId, uint256 newKeysCount ) external view returns (uint256); + + function penalize(uint256 nodeOperatorId, uint256 shares) external; } diff --git a/src/lib/Batch.sol b/src/lib/Batch.sol new file mode 100644 index 00000000..664f88f7 --- /dev/null +++ b/src/lib/Batch.sol @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: 2023 Lido +// SPDX-License-Identifier: GPL-3.0 +pragma solidity 0.8.21; + +/// @author madlabman +library Batch { + /// @notice Serialize node operator id, batch start and count of keys into a single bytes32 value + function serialize( + uint128 nodeOperatorId, + uint64 start, + uint64 count + ) internal pure returns (bytes32 s) { + return bytes32(abi.encodePacked(nodeOperatorId, start, count)); + } + + /// @notice Deserialize node operator id, batch start and count of keys from a single bytes32 value + function deserialize( + bytes32 b + ) internal pure returns (uint128 nodeOperatorId, uint64 start, uint64 count) { + assembly { + nodeOperatorId := shr(128, b) + start := shr(64, b) + count := b + } + } + + function isNil(bytes32 b) internal pure returns (bool) { + return b == bytes32(0); + } +} diff --git a/src/lib/QueueLib.sol b/src/lib/QueueLib.sol new file mode 100644 index 00000000..0283d87f --- /dev/null +++ b/src/lib/QueueLib.sol @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: 2023 Lido +// SPDX-License-Identifier: GPL-3.0 +pragma solidity 0.8.21; + + +/// @author madlabman +library QueueLib { + bytes32 public constant NULL_POINTER = bytes32(0); + + struct Queue { + mapping(bytes32 => bytes32) queue; + bytes32 front; + bytes32 back; + } + + function enqueue(Queue storage self, bytes32 item) internal { + require(item != NULL_POINTER, "Queue: item is zero"); + require(self.queue[item] == NULL_POINTER, "Queue: item already enqueued"); + + if (self.front == self.queue[self.front]) { + self.queue[self.front] = item; + } + + self.queue[self.back] = item; + self.back = item; + } + + function dequeue(Queue storage self) internal notEmpty(self) returns (bytes32 item) { + item = self.queue[self.front]; + self.front = item; + } + + function peek(Queue storage self) internal view returns (bytes32) { + return self.queue[self.front]; + } + + function at(Queue storage self, bytes32 pointer) internal view returns (bytes32) { + return self.queue[pointer]; + } + + function list(Queue storage self, bytes32 pointer, uint256 limit) internal notEmpty(self) view returns ( + bytes32[] memory items, + bytes32 /* pointer */, + uint256 /* count */ + ) { + items = new bytes32[](limit); + + uint256 i; + for (; i < limit; i++) { + bytes32 item = self.queue[pointer]; + if (item == NULL_POINTER) { + break; + } + + items[i] = item; + pointer = item; + } + + return (items, pointer, i); + } + + function isEmpty(Queue storage self) internal view returns (bool) { + return self.front == self.back; + } + + function remove(Queue storage self, bytes32 pointerToItem, bytes32 item) internal { + require(self.queue[pointerToItem] == item, "Queue: wrong pointer given"); + + self.queue[pointerToItem] = self.queue[item]; + self.queue[item] = NULL_POINTER; + + if (self.back == item) { + self.back = pointerToItem; + } + } + + modifier notEmpty(Queue storage self) { + require(!isEmpty(self), "Queue: empty"); + _; + } +} diff --git a/test/Batch.t.sol b/test/Batch.t.sol new file mode 100644 index 00000000..a5a197ad --- /dev/null +++ b/test/Batch.t.sol @@ -0,0 +1,54 @@ +// SPDX-FileCopyrightText: 2023 Lido +// SPDX-License-Identifier: GPL-3.0 +pragma solidity 0.8.21; + +import "forge-std/Test.sol"; + +import { Batch } from "../src/lib/Batch.sol"; + +contract BatchTest is Test { + function test_serialize() public { + bytes32 b = Batch.serialize({ + nodeOperatorId: 999, + start: 3, + count: 42 + }); + + assertEq( + b, + // noIndex | start | count | + 0x000000000000000000000000000003e70000000000000003000000000000002a + ); + } + + function test_deserialize() public { + (uint128 nodeOperatorId, uint64 start, uint64 count) = Batch + .deserialize( + 0x0000000000000000000000000000000000000000000000000000000000000000 + ); + + assertEq(nodeOperatorId, 0, "nodeOperatorId != 0"); + assertEq(start, 0, "start != 0"); + assertEq(count, 0, "count != 0"); + + (nodeOperatorId, start, count) = Batch.deserialize( + 0x000000000000000000000000000003e70000000000000003000000000000002a + ); + + assertEq(nodeOperatorId, 999, "nodeOperatorId != 999"); + assertEq(start, 3, "start != 3"); + assertEq(count, 42, "count != 42"); + + (nodeOperatorId, start, count) = Batch.deserialize( + 0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff + ); + + assertEq( + nodeOperatorId, + type(uint128).max, + "nodeOperatorId != uint128.max" + ); + assertEq(start, type(uint64).max, "start != uint64.max"); + assertEq(count, type(uint64).max, "count != uint64.max"); + } +} diff --git a/test/CSMAddValidator.t.sol b/test/CSMAddValidator.t.sol index f4c0b2b7..6e5fa11b 100644 --- a/test/CSMAddValidator.t.sol +++ b/test/CSMAddValidator.t.sol @@ -90,7 +90,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { { vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 1); + emit TotalSigningKeysCountChanged(0, 1); vm.expectEmit(true, true, false, true, address(csm)); emit NodeOperatorAdded(0, "test", nodeOperator); } @@ -111,7 +111,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { vm.expectEmit(true, true, true, true, address(wstETH)); emit Approval(nodeOperator, address(bondManager), wstETHAmount); vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 1); + emit TotalSigningKeysCountChanged(0, 1); vm.expectEmit(true, true, false, true, address(csm)); emit NodeOperatorAdded(0, "test", nodeOperator); } @@ -145,7 +145,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { (bytes memory keys, bytes memory signatures) = keysSignatures(1, 1); { vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 2); + emit TotalSigningKeysCountChanged(0, 2); } csm.addValidatorKeysWstETH(noId, 1, keys, signatures); } @@ -170,7 +170,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { vm.expectEmit(true, true, true, true, address(wstETH)); emit Approval(nodeOperator, address(bondManager), wstETHAmount); vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 2); + emit TotalSigningKeysCountChanged(0, 2); } vm.prank(stranger); csm.addValidatorKeysWstETHWithPermit( @@ -198,7 +198,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { { vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 1); + emit TotalSigningKeysCountChanged(0, 1); vm.expectEmit(true, true, false, true, address(csm)); emit NodeOperatorAdded(0, "test", nodeOperator); } @@ -218,7 +218,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { vm.expectEmit(true, true, true, true, address(stETH)); emit Approval(nodeOperator, address(bondManager), 2 ether); vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 1); + emit TotalSigningKeysCountChanged(0, 1); vm.expectEmit(true, true, false, true, address(csm)); emit NodeOperatorAdded(0, "test", nodeOperator); } @@ -252,7 +252,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { stETH.submit{ value: 2 ether }(address(0)); { vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 2); + emit TotalSigningKeysCountChanged(0, 2); } csm.addValidatorKeysStETH(noId, 1, keys, signatures); } @@ -274,7 +274,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { vm.expectEmit(true, true, true, true, address(stETH)); emit Approval(nodeOperator, address(bondManager), required); vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 2); + emit TotalSigningKeysCountChanged(0, 2); } vm.prank(stranger); csm.addValidatorKeysStETHWithPermit( @@ -303,7 +303,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { { vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 1); + emit TotalSigningKeysCountChanged(0, 1); vm.expectEmit(true, true, false, true, address(csm)); emit NodeOperatorAdded(0, "test", nodeOperator); } @@ -328,7 +328,7 @@ contract CSMAddNodeOperator is CSMCommon, PermitTokenBase { vm.prank(nodeOperator); { vm.expectEmit(true, true, false, true, address(csm)); - emit TotalKeysCountChanged(0, 2); + emit TotalSigningKeysCountChanged(0, 2); } csm.addValidatorKeysETH{ value: required }(noId, 1, keys, signatures); } @@ -349,6 +349,12 @@ contract CSMObtainDepositData is CSMCommon { keys, signatures ); + + { + // Pretend to be a key validation oracle + csm.vetKeys(0, 1); + } + (bytes memory obtainedKeys, bytes memory obtainedSignatures) = csm .obtainDepositData(1, ""); assertEq(obtainedKeys, keys); diff --git a/test/CSMInit.t.sol b/test/CSMInit.t.sol index 0660f72c..f1653989 100644 --- a/test/CSMInit.t.sol +++ b/test/CSMInit.t.sol @@ -56,6 +56,6 @@ contract CSMInitTest is Test, Fixtures { function test_SetBondManager() public { csm.setBondManager(address(bondManager)); - assertEq(address(csm.bondManagerAddress()), address(bondManager)); + assertEq(address(csm.bondManager()), address(bondManager)); } } diff --git a/test/QueueLib.t.sol b/test/QueueLib.t.sol new file mode 100644 index 00000000..2f6acc66 --- /dev/null +++ b/test/QueueLib.t.sol @@ -0,0 +1,118 @@ +// SPDX-FileCopyrightText: 2023 Lido +// SPDX-License-Identifier: GPL-3.0 +pragma solidity 0.8.21; + +import "forge-std/Test.sol"; +import "forge-std/console.sol"; + +import { QueueLib } from "../src/lib/QueueLib.sol"; + +contract QueueLibTest is Test { + bytes32 p0 = keccak256("0x00"); // 0x27489e20a0060b723a1748bdff5e44570ee9fae64141728105692eac6031e8a4 + bytes32 p1 = keccak256("0x01"); // 0xe127292c8f7eb20e1ae830ed6055b6eb36e261836100610d12677231d0791f7f + bytes32 p2 = keccak256("0x02"); // 0xd3974deccfd8aa6b77f0fcc2c0014e6e0574d32e56c1d75717d2667b529cd073 + + bytes32 nil = bytes32(0); + bytes32 buf; + + using QueueLib for QueueLib.Queue; + QueueLib.Queue q; + + function test_enqueue() public { + assertEq(q.peek(), nil); + + q.enqueue(p0); + q.enqueue(p1); + + assertEq(q.peek(), p0); + assertEq(q.at(p0), p1); + } + + function test_dequeue() public { + assertTrue(q.isEmpty()); + + q.enqueue(p0); + q.enqueue(p1); + q.enqueue(p2); + + assertFalse(q.isEmpty()); + + buf = q.dequeue(); + assertEq(buf, p0); + assertEq(q.peek(), p1); + + buf = q.dequeue(); + assertEq(buf, p1); + assertEq(q.peek(), p2); + + q.dequeue(); + assertEq(q.peek(), nil); + assertTrue(q.isEmpty()); + } + + function test_list() public { + q.enqueue(p0); + q.enqueue(p1); + q.enqueue(p2); + + { + (bytes32[] memory items, bytes32 pointer, uint256 count) = q.list( + q.front, + 2 + ); + assertEq(count, 2); + assertEq(pointer, p1); + assertEq(items[0], p0); + assertEq(items[1], p1); + } + + { + (bytes32[] memory items, bytes32 pointer, uint256 count) = q.list( + p1, + 999 + ); + assertEq(count, 1); + assertEq(pointer, p2); + assertEq(items[0], p2); + } + + q.dequeue(); + + { + (, bytes32 pointer, uint256 count) = q.list(q.front, 0); + assertEq(count, 0); + assertEq(pointer, q.front); + } + } + + function test_remove() public { + q.enqueue(p0); + q.enqueue(p1); + q.enqueue(p2); + // [+*p0, p1, p2] + + q.remove(p0, p1); + // [+*p0, p2] + + q.dequeue(); + // [+p0, *p2] + buf = q.dequeue(); + // [p0, +*p2] + assertEq(buf, p2); + + q.enqueue(p1); + // [p0, +p2, *p1] + assertEq(q.peek(), p1); + + q.remove(p2, p1); + // [p0, +*p2] + assertEq(q.peek(), nil); + assertTrue(q.isEmpty()); + + q.remove(p0, p2); + // [+*p0] + assertEq(q.peek(), nil); + } + + // TODO: test with revert on library call +} diff --git a/test/integration/StakingRouter.t.sol b/test/integration/StakingRouter.t.sol index 0bccced4..5b5b5b2b 100644 --- a/test/integration/StakingRouter.t.sol +++ b/test/integration/StakingRouter.t.sol @@ -101,18 +101,23 @@ contract StakingRouterIntegrationTest is Test, Utilities { _treasuryFee: 500 }); uint256[] memory ids = stakingRouter.getStakingModuleIds(); - (bytes memory keys, bytes memory signatures) = keysSignatures(1); + (bytes memory keys, bytes memory signatures) = keysSignatures(2); address nodeOperator = address(2); - vm.deal(nodeOperator, 2 ether); + vm.deal(nodeOperator, 4 ether); vm.prank(nodeOperator); - csm.addNodeOperatorETH{ value: 2 ether }( + csm.addNodeOperatorETH{ value: 4 ether }( "test", nodeOperator, - 1, + 2, keys, signatures ); + { + // Pretend to be a key validation oracle + csm.vetKeys(0, 2); + } + // It's impossible to process deposits if withdrawal requests amount is more than the buffered ether, // so we need to make sure that the buffered ether is enough by submitting this tremendous amount. address whale = nextAddress();