Skip to content

Commit

Permalink
refactor: make Queue contract a library
Browse files Browse the repository at this point in the history
  • Loading branch information
madlabman committed Oct 9, 2023
1 parent 5eaa51f commit c532b50
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 121 deletions.
3 changes: 1 addition & 2 deletions script/Deploy.s.sol
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ contract Deploy is Script {
wstETH = IWstETH(WSTETH_ADDRESS);
CommunityStakingModule csm = new CommunityStakingModule(
"community-staking-module",
address(locator),
address(90210) // FIXME
address(locator)
);
CommunityStakingBondManager bondManager = new CommunityStakingBondManager({
_commonBondSize: 2 ether,
Expand Down
40 changes: 17 additions & 23 deletions src/CommunityStakingModule.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import { Math } from "@openzeppelin/contracts/utils/math/Math.sol";
import { ICommunityStakingBondManager } from "./interfaces/ICommunityStakingBondManager.sol";
import { IStakingModule } from "./interfaces/IStakingModule.sol";
import { ILidoLocator } from "./interfaces/ILidoLocator.sol";
import { IQueue } from "./interfaces/IQueue.sol";
import { ILido } from "./interfaces/ILido.sol";

import { QueueLib } from "./lib/QueueLib.sol";
import { Batch } from "./lib/Batch.sol";

import "./lib/SigningKeys.sol";
Expand Down Expand Up @@ -59,6 +59,8 @@ contract CommunityStakingModuleBase {
}

contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase {
using QueueLib for QueueLib.Queue;

uint256 private nodeOperatorsCount;
uint256 private activeNodeOperatorsCount;
bytes32 private moduleType;
Expand All @@ -68,24 +70,22 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase {
bytes32 public constant SIGNING_KEYS_POSITION =
keccak256("lido.CommunityStakingModule.signingKeysPosition");

QueueLib.Queue public queue;

address public bondManagerAddress;
address public lidoLocator;
address public queue;

event VettedSigningKeysCountChanged(
uint256 indexed nodeOperatorId,
uint256 approvedValidatorsCount
);

constructor(bytes32 _type, address _locator, address _queue) {
constructor(bytes32 _type, address _locator) {
moduleType = _type;
nodeOperatorsCount = 0;

require(_locator != address(0), "lido locator is zero address");
lidoLocator = _locator;

require(_queue != address(0), "Queue address is zero address");
queue = _queue;
}

function setBondManager(address _bondManagerAddress) external {
Expand Down Expand Up @@ -472,17 +472,14 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase {
});

no.totalVettedKeys = _vettedKeysCount;
IQueue(queue).enqueue(pointer);
queue.enqueue(pointer);
emit VettedSigningKeysCountChanged(_nodeOperatorId, _vettedKeysCount);
}

function unvetKeys(uint64 _nodeOperatorId) external {
NodeOperator storage no = nodeOperators[_nodeOperatorId];
no.totalVettedKeys = no.totalDepositedKeys;
emit VettedSigningKeysCountChanged(
_nodeOperatorId,
no.totalVettedKeys
);
emit VettedSigningKeysCountChanged(_nodeOperatorId, no.totalVettedKeys);
}

function onWithdrawalCredentialsChanged() external {
Expand Down Expand Up @@ -522,7 +519,7 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase {
) external returns (bytes memory publicKeys, bytes memory signatures) {
uint256 limit = _depositsCount;

for (bytes32 p = IQueue(queue).front(); !Batch.isNil(p); ) {
for (bytes32 p = queue.peek(); !Batch.isNil(p); ) {
(uint256 nodeOperatorId, uint256 start, uint256 end) = Batch
.deserialize(p);

Expand All @@ -541,7 +538,7 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase {
break;
}

p = IQueue(queue).front();
p = queue.peek();
}
}

Expand All @@ -565,18 +562,15 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase {
require(_end < no.totalVettedKeys, "NO was unvetted");
require(_end < no.totalAddedKeys, "not enough keys");

require(
no.totalDepositedKeys >= _start,
"invalid range: skipped keys"
);
require(no.totalDepositedKeys >= _start, "invalid range: skipped keys");

uint256 _startIndex = Math.max(_start, no.totalDepositedKeys);
uint256 _endIndex = Math.min(_end, _startIndex + limit);
count = _endIndex - _startIndex + 1;

no.totalDepositedKeys = _endIndex + 1;
if (_end == _endIndex) {
IQueue(queue).dequeue();
queue.dequeue();
}

SigningKeys.loadKeysSigs(
Expand All @@ -596,18 +590,18 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase {
bytes32 pointer
) external returns (bytes32) {
if (Batch.isNil(pointer)) {
pointer = IQueue(queue).frontPointer();
pointer = queue.front;
}

for (uint256 i; i < maxItems; i++) {
bytes32 item = IQueue(queue).at(pointer);
bytes32 item = queue.at(pointer);
if (Batch.isNil(item)) {
break;
}

(uint256 nodeOperatorId, , uint256 end) = Batch.deserialize(item);
if (_unvettedKeysInBatch(nodeOperatorId, end)) {
IQueue(queue).remove(pointer, item);
queue.remove(pointer, item);
}

pointer = item;
Expand All @@ -622,11 +616,11 @@ contract CommunityStakingModule is IStakingModule, CommunityStakingModuleBase {
bytes32 pointer
) external view returns (bool, bytes32) {
if (Batch.isNil(pointer)) {
pointer = IQueue(queue).frontPointer();
pointer = queue.front;
}

for (uint256 i; i < maxItems; i++) {
bytes32 item = IQueue(queue).at(pointer);
bytes32 item = queue.at(pointer);
if (Batch.isNil(item)) {
break;
}
Expand Down
58 changes: 0 additions & 58 deletions src/Queue.sol

This file was deleted.

17 changes: 0 additions & 17 deletions src/interfaces/IQueue.sol

This file was deleted.

56 changes: 56 additions & 0 deletions src/lib/QueueLib.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// SPDX-FileCopyrightText: 2023 Lido <[email protected]>
// SPDX-License-Identifier: GPL-3.0
pragma solidity 0.8.21;


/// @author madlabman
library QueueLib {
bytes32 public constant NULL_POINTER = bytes32(0);

struct Queue {
mapping(bytes32 => bytes32) queue;
bytes32 front;
bytes32 back;
}

function enqueue(Queue storage self, bytes32 item) internal {
require(item != NULL_POINTER, "Queue: item is zero");
require(self.queue[item] == NULL_POINTER, "Queue: item already enqueued");

if (self.front == self.queue[self.front]) {
self.queue[self.front] = item;
}

self.queue[self.back] = item;
self.back = item;
}

function dequeue(Queue storage self) internal notEmpty(self) returns (bytes32 item) {
item = self.queue[self.front];
self.front = item;
}

function peek(Queue storage self) internal view returns (bytes32) {
return self.queue[self.front];
}

function at(Queue storage self, bytes32 pointer) internal view returns (bytes32) {
return self.queue[pointer];
}

function remove(Queue storage self, bytes32 pointerToItem, bytes32 item) internal {
require(self.queue[pointerToItem] == item, "Queue: wrong pointer given");

self.queue[pointerToItem] = self.queue[item];
self.queue[item] = NULL_POINTER;

if (self.back == item) {
self.back = pointerToItem;
}
}

modifier notEmpty(Queue storage self) {
require(self.front != self.back, "Queue: empty");
_;
}
}
3 changes: 1 addition & 2 deletions test/CSMAddValidator.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ contract CSMAddNodeOperator is
);
csm = new CommunityStakingModule(
"community-staking-module",
address(locator),
address(90210) // FIXME
address(locator)
);
bondManager = new CommunityStakingBondManager(
2 ether,
Expand Down
3 changes: 1 addition & 2 deletions test/CSMInit.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ contract CSMInitTest is Test, Fixtures {

csm = new CommunityStakingModule(
"community-staking-module",
address(locator),
address(90210) // FIXME
address(locator)
);
communityStakingFeeDistributor = new CommunityStakingFeeDistributorMock(
address(locator),
Expand Down
27 changes: 12 additions & 15 deletions test/Queue.t.sol → test/QueueLib.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,26 @@ pragma solidity 0.8.21;
import "forge-std/Test.sol";
import "forge-std/console.sol";

import { Queue } from "../src/Queue.sol";
import { QueueLib } from "../src/lib/QueueLib.sol";

contract QueueTest is Test {
bytes32 p0 = keccak256("0x00"); // 0x27489e20a0060b723a1748bdff5e44570ee9fae64141728105692eac6031e8a4
bytes32 p1 = keccak256("0x01"); // 0xe127292c8f7eb20e1ae830ed6055b6eb36e261836100610d12677231d0791f7f
bytes32 p2 = keccak256("0x02"); // 0xd3974deccfd8aa6b77f0fcc2c0014e6e0574d32e56c1d75717d2667b529cd073

bytes32 nil = bytes32(0);
bytes32 buf;
bytes32 nil;
Queue q;

function setUp() public {
q = new Queue();
nil = q.NULL_POINTER();
}
using QueueLib for QueueLib.Queue;
QueueLib.Queue q;

function test_enqueue() public {
assertEq(q.front(), nil);
assertEq(q.peek(), nil);

q.enqueue(p0);
q.enqueue(p1);

assertEq(q.front(), p0);
assertEq(q.peek(), p0);
assertEq(q.at(p0), p1);
}

Expand All @@ -43,14 +40,14 @@ contract QueueTest is Test {

buf = q.dequeue();
assertEq(buf, p0);
assertEq(q.front(), p1);
assertEq(q.peek(), p1);

buf = q.dequeue();
assertEq(buf, p1);
assertEq(q.front(), p2);
assertEq(q.peek(), p2);

q.dequeue();
assertEq(q.front(), nil);
assertEq(q.peek(), nil);

{
vm.expectRevert("Queue: empty");
Expand All @@ -75,18 +72,18 @@ contract QueueTest is Test {

q.enqueue(p1);
// [p0, +p2, *p1]
assertEq(q.front(), p1);
assertEq(q.peek(), p1);

q.remove(p2, p1);
// [p0, +*p2]
assertEq(q.front(), nil);
assertEq(q.peek(), nil);
{
vm.expectRevert("Queue: empty");
q.dequeue();
}

q.remove(p0, p2);
// [+*p0]
assertEq(q.front(), nil);
assertEq(q.peek(), nil);
}
}
Loading

0 comments on commit c532b50

Please sign in to comment.