From 8fcfe96ce5908fa6cfdb93e18563402495e5998e Mon Sep 17 00:00:00 2001 From: vgorkavenko Date: Wed, 30 Oct 2024 14:12:54 +0100 Subject: [PATCH] feat: per frame data --- src/modules/csm/checkpoint.py | 5 +- src/modules/csm/csm.py | 73 +++++++--- src/modules/csm/log.py | 6 +- src/modules/csm/state.py | 145 ++++++++++++++------ tests/modules/csm/test_checkpoint.py | 14 +- tests/modules/csm/test_csm_module.py | 196 ++++++++++++++++++++++++--- tests/modules/csm/test_log.py | 61 +++++++-- tests/modules/csm/test_state.py | 149 ++++++++++++-------- 8 files changed, 496 insertions(+), 153 deletions(-) diff --git a/src/modules/csm/checkpoint.py b/src/modules/csm/checkpoint.py index 0efc326c6..b111fe197 100644 --- a/src/modules/csm/checkpoint.py +++ b/src/modules/csm/checkpoint.py @@ -143,6 +143,7 @@ def exec(self, checkpoint: FrameCheckpoint) -> int: for duty_epoch in unprocessed_epochs } self._process(unprocessed_epochs, duty_epochs_roots) + self.state.commit() return len(unprocessed_epochs) def _get_block_roots(self, checkpoint_slot: SlotNumber): @@ -208,14 +209,14 @@ def _check_duty( with lock: for committee in committees.values(): for validator_duty in committee: - self.state.inc( + self.state.increment_duty( + duty_epoch, validator_duty.index, included=validator_duty.included, ) if duty_epoch not in self.state.unprocessed_epochs: raise ValueError(f"Epoch {duty_epoch} is not in epochs that should be processed") self.state.add_processed_epoch(duty_epoch) - self.state.commit() self.state.log_progress() unprocessed_epochs = self.state.unprocessed_epochs CSM_UNPROCESSED_EPOCHS_COUNT.set(len(unprocessed_epochs)) diff --git a/src/modules/csm/csm.py b/src/modules/csm/csm.py index 543a276e2..3d3619a94 100644 --- a/src/modules/csm/csm.py +++ b/src/modules/csm/csm.py @@ -13,7 +13,7 @@ from src.metrics.prometheus.duration_meter import duration_meter from src.modules.csm.checkpoint import FrameCheckpointProcessor, FrameCheckpointsIterator, MinStepIsNotReached from src.modules.csm.log import FramePerfLog -from src.modules.csm.state import State +from src.modules.csm.state import State, Frame from src.modules.csm.tree import Tree from src.modules.csm.types import ReportData, Shares from src.modules.submodules.consensus import ConsensusModule @@ -29,10 +29,11 @@ SlotNumber, StakingModuleAddress, StakingModuleId, + ValidatorIndex, ) from src.utils.blockstamp import build_blockstamp from src.utils.cache import global_lru_cache as lru_cache -from src.utils.slot import get_next_non_missed_slot +from src.utils.slot import get_next_non_missed_slot, get_reference_blockstamp from src.utils.web3converter import Web3Converter from src.web3py.extensions.lido_validators import NodeOperatorId, StakingModule, ValidatorsByNodeOperator from src.web3py.types import Web3 @@ -101,12 +102,12 @@ def build_report(self, blockstamp: ReferenceBlockStamp) -> tuple: if (prev_cid is None) != (prev_root == ZERO_HASH): raise InconsistentData(f"Got inconsistent previous tree data: {prev_root=} {prev_cid=}") - distributed, shares, log = self.calculate_distribution(blockstamp) + distributed, shares, logs = self.calculate_distribution(blockstamp) if distributed != sum(shares.values()): raise InconsistentData(f"Invalid distribution: {sum(shares.values())=} != {distributed=}") - log_cid = self.publish_log(log) + log_cid = self.publish_log(logs) if not distributed and not shares: logger.info({"msg": "No shares distributed in the current frame"}) @@ -201,7 +202,7 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: logger.info({"msg": "The starting epoch of the frame is not finalized yet"}) return False - self.state.migrate(l_epoch, r_epoch, consensus_version) + self.state.init_or_migrate(l_epoch, r_epoch, converter.frame_config.epochs_per_frame, consensus_version) self.state.log_progress() if self.state.is_fulfilled: @@ -227,17 +228,56 @@ def collect_data(self, blockstamp: BlockStamp) -> bool: def calculate_distribution( self, blockstamp: ReferenceBlockStamp - ) -> tuple[int, defaultdict[NodeOperatorId, int], FramePerfLog]: + ) -> tuple[int, defaultdict[NodeOperatorId, int], list[FramePerfLog]]: """Computes distribution of fee shares at the given timestamp""" - - network_avg_perf = self.state.get_network_aggr().perf - threshold = network_avg_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS operators_to_validators = self.module_validators_by_node_operators(blockstamp) + distributed = 0 + # Calculate share of each CSM node operator. + shares = defaultdict[NodeOperatorId, int](int) + logs: list[FramePerfLog] = [] + + for frame in self.state.data: + from_epoch, to_epoch = frame + logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"}) + frame_blockstamp = blockstamp + if to_epoch != blockstamp.ref_epoch: + frame_blockstamp = self._get_ref_blockstamp_for_frame(blockstamp, to_epoch) + distributed_in_frame, shares_in_frame, log = self._calculate_distribution_in_frame( + frame_blockstamp, operators_to_validators, frame, distributed + ) + distributed += distributed_in_frame + for no_id, share in shares_in_frame.items(): + shares[no_id] += share + logs.append(log) + + return distributed, shares, logs + + def _get_ref_blockstamp_for_frame( + self, blockstamp: ReferenceBlockStamp, frame_ref_epoch: EpochNumber + ) -> ReferenceBlockStamp: + converter = self.converter(blockstamp) + return get_reference_blockstamp( + cc=self.w3.cc, + ref_slot=converter.get_epoch_last_slot(frame_ref_epoch), + ref_epoch=frame_ref_epoch, + last_finalized_slot_number=blockstamp.slot_number, + ) + + def _calculate_distribution_in_frame( + self, + blockstamp: ReferenceBlockStamp, + operators_to_validators: ValidatorsByNodeOperator, + frame: Frame, + distributed: int, + ): + network_perf = self.state.get_network_aggr(frame).perf + threshold = network_perf - self.w3.csm.oracle.perf_leeway_bp(blockstamp.block_hash) / TOTAL_BASIS_POINTS + # Build the map of the current distribution operators. distribution: dict[NodeOperatorId, int] = defaultdict(int) stuck_operators = self.stuck_operators(blockstamp) - log = FramePerfLog(blockstamp, self.state.frame, threshold) + log = FramePerfLog(blockstamp, frame, threshold) for (_, no_id), validators in operators_to_validators.items(): if no_id in stuck_operators: @@ -245,7 +285,7 @@ def calculate_distribution( continue for v in validators: - aggr = self.state.data.get(v.index) + aggr = self.state.data[frame].get(ValidatorIndex(int(v.index))) if aggr is None: # It's possible that the validator is not assigned to any duty, hence it's performance @@ -268,13 +308,12 @@ def calculate_distribution( # Calculate share of each CSM node operator. shares = defaultdict[NodeOperatorId, int](int) total = sum(p for p in distribution.values()) + to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - distributed + log.distributable = to_distribute if not total: return 0, shares, log - to_distribute = self.w3.csm.fee_distributor.shares_to_distribute(blockstamp.block_hash) - log.distributable = to_distribute - for no_id, no_share in distribution.items(): if no_share: shares[no_id] = to_distribute * no_share // total @@ -348,9 +387,9 @@ def publish_tree(self, tree: Tree) -> CID: logger.info({"msg": "Tree dump uploaded to IPFS", "cid": repr(tree_cid)}) return tree_cid - def publish_log(self, log: FramePerfLog) -> CID: - log_cid = self.w3.ipfs.publish(log.encode()) - logger.info({"msg": "Frame log uploaded to IPFS", "cid": repr(log_cid)}) + def publish_log(self, logs: list[FramePerfLog]) -> CID: + log_cid = self.w3.ipfs.publish(FramePerfLog.encode(logs)) + logger.info({"msg": "Frame(s) log uploaded to IPFS", "cid": repr(log_cid)}) return log_cid @lru_cache(maxsize=1) diff --git a/src/modules/csm/log.py b/src/modules/csm/log.py index f89f4ef58..39832c8c0 100644 --- a/src/modules/csm/log.py +++ b/src/modules/csm/log.py @@ -12,6 +12,7 @@ class LogJSONEncoder(json.JSONEncoder): ... @dataclass class ValidatorFrameSummary: + # TODO: Should be renamed. Perf means different things in different contexts perf: AttestationsAccumulator = field(default_factory=AttestationsAccumulator) slashed: bool = False @@ -35,13 +36,14 @@ class FramePerfLog: default_factory=lambda: defaultdict(OperatorFrameSummary) ) - def encode(self) -> bytes: + @staticmethod + def encode(logs: list['FramePerfLog']) -> bytes: return ( LogJSONEncoder( indent=None, separators=(',', ':'), sort_keys=True, ) - .encode(asdict(self)) + .encode([asdict(log) for log in logs]) .encode() ) diff --git a/src/modules/csm/state.py b/src/modules/csm/state.py index 4373f5259..fd27a8d62 100644 --- a/src/modules/csm/state.py +++ b/src/modules/csm/state.py @@ -3,6 +3,7 @@ import pickle from collections import defaultdict from dataclasses import dataclass +from itertools import batched from pathlib import Path from typing import Self @@ -12,6 +13,8 @@ logger = logging.getLogger(__name__) +type Frame = tuple[EpochNumber, EpochNumber] + class InvalidState(ValueError): """State has data considered as invalid for a report""" @@ -43,18 +46,21 @@ class State: The state can be migrated to be used for another frame's report by calling the `migrate` method. """ - - data: defaultdict[ValidatorIndex, AttestationsAccumulator] + data: dict[Frame, defaultdict[ValidatorIndex, AttestationsAccumulator]] _epochs_to_process: tuple[EpochNumber, ...] _processed_epochs: set[EpochNumber] + _epochs_per_frame: int _consensus_version: int = 1 - def __init__(self, data: dict[ValidatorIndex, AttestationsAccumulator] | None = None) -> None: - self.data = defaultdict(AttestationsAccumulator, data or {}) + def __init__(self, data: dict[Frame, dict[ValidatorIndex, AttestationsAccumulator]] | None = None) -> None: + self.data = { + frame: defaultdict(AttestationsAccumulator, validators) for frame, validators in (data or {}).items() + } self._epochs_to_process = tuple() self._processed_epochs = set() + self._epochs_per_frame = 0 EXTENSION = ".pkl" @@ -89,14 +95,37 @@ def file(cls) -> Path: def buffer(self) -> Path: return self.file().with_suffix(".buf") + @property + def is_empty(self) -> bool: + return not self.data and not self._epochs_to_process and not self._processed_epochs + + @property + def unprocessed_epochs(self) -> set[EpochNumber]: + if not self._epochs_to_process: + raise ValueError("Epochs to process are not set") + diff = set(self._epochs_to_process) - self._processed_epochs + return diff + + @property + def is_fulfilled(self) -> bool: + return not self.unprocessed_epochs + def clear(self) -> None: - self.data = defaultdict(AttestationsAccumulator) + self.data = {} self._epochs_to_process = tuple() self._processed_epochs.clear() assert self.is_empty - def inc(self, key: ValidatorIndex, included: bool) -> None: - self.data[key].add_duty(included) + def find_frame(self, epoch: EpochNumber) -> Frame: + frames = self.data.keys() + for epoch_range in frames: + if epoch_range[0] <= epoch <= epoch_range[1]: + return epoch_range + raise ValueError(f"Epoch {epoch} is out of frames range: {frames}") + + def increment_duty(self, epoch: EpochNumber, val_index: ValidatorIndex, included: bool) -> None: + epoch_range = self.find_frame(epoch) + self.data[epoch_range][val_index].add_duty(included) def add_processed_epoch(self, epoch: EpochNumber) -> None: self._processed_epochs.add(epoch) @@ -104,7 +133,7 @@ def add_processed_epoch(self, epoch: EpochNumber) -> None: def log_progress(self) -> None: logger.info({"msg": f"Processed {len(self._processed_epochs)} of {len(self._epochs_to_process)} epochs"}) - def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version: int): + def init_or_migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int, consensus_version: int) -> None: if consensus_version != self._consensus_version: logger.warning( { @@ -114,17 +143,60 @@ def migrate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, consensus_version: ) self.clear() - for state_epochs in (self._epochs_to_process, self._processed_epochs): - for epoch in state_epochs: - if epoch < l_epoch or epoch > r_epoch: - logger.warning({"msg": "Discarding invalidated state cache"}) - self.clear() - break + if not self.is_empty: + invalidated = self._migrate_or_invalidate(l_epoch, r_epoch, epochs_per_frame) + if invalidated: + self.clear() + self._fill_frames(l_epoch, r_epoch, epochs_per_frame) + self._epochs_per_frame = epochs_per_frame self._epochs_to_process = tuple(sequence(l_epoch, r_epoch)) self._consensus_version = consensus_version self.commit() + def _fill_frames(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> None: + frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + for frame in frames: + self.data.setdefault(frame, defaultdict(AttestationsAccumulator)) + + def _migrate_or_invalidate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> bool: + current_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame) + new_frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame) + inv_msg = f"Discarding invalid state cache because of frames change. {current_frames=}, {new_frames=}" + + if self._invalidate_on_epoch_range_change(l_epoch, r_epoch): + logger.warning({"msg": inv_msg}) + return True + + frame_expanded = epochs_per_frame > self._epochs_per_frame + frame_shrunk = epochs_per_frame < self._epochs_per_frame + + has_single_frame = len(current_frames) == len(new_frames) == 1 + + if has_single_frame and frame_expanded: + current_frame, *_ = current_frames + new_frame, *_ = new_frames + self.data[new_frame] = self.data.pop(current_frame) + logger.info({"msg": f"Migrated state cache to a new frame. {current_frame=}, {new_frame=}"}) + return False + + if has_single_frame and frame_shrunk: + logger.warning({"msg": inv_msg}) + return True + + if not has_single_frame and frame_expanded or frame_shrunk: + logger.warning({"msg": inv_msg}) + return True + + return False + + def _invalidate_on_epoch_range_change(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> bool: + """Check if the epoch range has been invalidated.""" + for epoch_set in (self._epochs_to_process, self._processed_epochs): + if any(epoch < l_epoch or epoch > r_epoch for epoch in epoch_set): + return True + return False + def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: if not self.is_fulfilled: raise InvalidState(f"State is not fulfilled. {self.unprocessed_epochs=}") @@ -135,34 +207,25 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None: for epoch in sequence(l_epoch, r_epoch): if epoch not in self._processed_epochs: - raise InvalidState(f"Epoch {epoch} should be processed") - - @property - def is_empty(self) -> bool: - return not self.data and not self._epochs_to_process and not self._processed_epochs - - @property - def unprocessed_epochs(self) -> set[EpochNumber]: - if not self._epochs_to_process: - raise ValueError("Epochs to process are not set") - diff = set(self._epochs_to_process) - self._processed_epochs - return diff - - @property - def is_fulfilled(self) -> bool: - return not self.unprocessed_epochs - - @property - def frame(self) -> tuple[EpochNumber, EpochNumber]: - if not self._epochs_to_process: - raise ValueError("Epochs to process are not set") - return min(self._epochs_to_process), max(self._epochs_to_process) - - def get_network_aggr(self) -> AttestationsAccumulator: - """Return `AttestationsAccumulator` over duties of all the network validators""" - + raise InvalidState(f"Epoch {epoch} missing in processed epochs") + + @staticmethod + def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]: + """Split epochs to process into frames of `epochs_per_frame` length""" + frames = [] + for frame_epochs in batched(epochs_to_process, epochs_per_frame): + if len(frame_epochs) < epochs_per_frame: + raise ValueError("Insufficient epochs to form a frame") + frames.append((frame_epochs[0], frame_epochs[-1])) + return frames + + def get_network_aggr(self, frame: Frame) -> AttestationsAccumulator: + # TODO: exclude `active_slashed` validators from the calculation included = assigned = 0 - for validator, acc in self.data.items(): + frame_data = self.data.get(frame) + if not frame_data: + raise ValueError(f"No data for frame {frame} to calculate network aggregate") + for validator, acc in frame_data.items(): if acc.included > acc.assigned: raise ValueError(f"Invalid accumulator: {validator=}, {acc=}") included += acc.included diff --git a/tests/modules/csm/test_checkpoint.py b/tests/modules/csm/test_checkpoint.py index 44f23735e..4b456ed03 100644 --- a/tests/modules/csm/test_checkpoint.py +++ b/tests/modules/csm/test_checkpoint.py @@ -326,7 +326,7 @@ def test_checkpoints_processor_no_eip7549_support( monkeypatch: pytest.MonkeyPatch, ): state = State() - state.migrate(EpochNumber(0), EpochNumber(255), 1) + state.init_or_migrate(EpochNumber(0), EpochNumber(255), 256, 1) processor = FrameCheckpointProcessor( consensus_client, state, @@ -354,7 +354,7 @@ def test_checkpoints_processor_check_duty( converter, ): state = State() - state.migrate(0, 255, 1) + state.init_or_migrate(0, 255, 256, 1) finalized_blockstamp = ... processor = FrameCheckpointProcessor( consensus_client, @@ -367,7 +367,7 @@ def test_checkpoints_processor_check_duty( assert len(state._processed_epochs) == 1 assert len(state._epochs_to_process) == 256 assert len(state.unprocessed_epochs) == 255 - assert len(state.data) == 2048 * 32 + assert len(state.data[(0, 255)]) == 2048 * 32 def test_checkpoints_processor_process( @@ -379,7 +379,7 @@ def test_checkpoints_processor_process( converter, ): state = State() - state.migrate(0, 255, 1) + state.init_or_migrate(0, 255, 256, 1) finalized_blockstamp = ... processor = FrameCheckpointProcessor( consensus_client, @@ -392,7 +392,7 @@ def test_checkpoints_processor_process( assert len(state._processed_epochs) == 2 assert len(state._epochs_to_process) == 256 assert len(state.unprocessed_epochs) == 254 - assert len(state.data) == 2048 * 32 + assert len(state.data[(0, 255)]) == 2048 * 32 def test_checkpoints_processor_exec( @@ -404,7 +404,7 @@ def test_checkpoints_processor_exec( converter, ): state = State() - state.migrate(0, 255, 1) + state.init_or_migrate(0, 255, 256, 1) finalized_blockstamp = ... processor = FrameCheckpointProcessor( consensus_client, @@ -418,4 +418,4 @@ def test_checkpoints_processor_exec( assert len(state._processed_epochs) == 2 assert len(state._epochs_to_process) == 256 assert len(state.unprocessed_epochs) == 254 - assert len(state.data) == 2048 * 32 + assert len(state.data[(0, 255)]) == 2048 * 32 diff --git a/tests/modules/csm/test_csm_module.py b/tests/modules/csm/test_csm_module.py index f74af8d69..cdb0c92c5 100644 --- a/tests/modules/csm/test_csm_module.py +++ b/tests/modules/csm/test_csm_module.py @@ -9,7 +9,7 @@ from src.constants import UINT64_MAX from src.modules.csm.csm import CSOracle -from src.modules.csm.state import AttestationsAccumulator, State +from src.modules.csm.state import AttestationsAccumulator, State, Frame from src.modules.csm.tree import Tree from src.modules.submodules.oracle_module import ModuleExecuteDelay from src.modules.submodules.types import CurrentFrame, ZERO_HASH @@ -166,26 +166,37 @@ def test_calculate_distribution(module: CSOracle, csm: CSM): ] ) + frame_0: Frame = (EpochNumber(0), EpochNumber(999)) + + module.state.init_or_migrate(*frame_0, epochs_per_frame=1000, consensus_version=1) module.state = State( { - ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame - ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), - ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming - ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming - ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), - ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming - # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state - ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), - ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), + frame_0: { + ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame + ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), + ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming + ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming + ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming + # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state + ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), + } } ) - module.state.migrate(EpochNumber(100), EpochNumber(500), 1) - _, shares, log = module.calculate_distribution(blockstamp=Mock()) + l_epoch, r_epoch = frame_0 + + frame_0_network_aggr = module.state.get_network_aggr(frame_0) + + blockstamp = ReferenceBlockStampFactory.build(slot_number=r_epoch * 32, ref_epoch=r_epoch, ref_slot=r_epoch * 32) + _, shares, logs = module.calculate_distribution(blockstamp=blockstamp) + + log, *_ = logs assert tuple(shares.items()) == ( (NodeOperatorId(0), 476), @@ -225,8 +236,157 @@ def test_calculate_distribution(module: CSOracle, csm: CSM): assert log.operators[NodeOperatorId(3)].distributed == 2380 assert log.operators[NodeOperatorId(6)].distributed == 2380 - assert log.frame == (100, 500) - assert log.threshold == module.state.get_network_aggr().perf - 0.05 + assert log.frame == frame_0 + assert log.threshold == frame_0_network_aggr.perf - 0.05 + + +def test_calculate_distribution_with_missed_with_two_frames(module: CSOracle, csm: CSM): + csm.oracle.perf_leeway_bp = Mock(return_value=500) + csm.fee_distributor.shares_to_distribute = Mock(side_effect=[10000, 20000]) + + module.module_validators_by_node_operators = Mock( + return_value={ + (None, NodeOperatorId(0)): [Mock(index=0, validator=Mock(slashed=False))], + (None, NodeOperatorId(1)): [Mock(index=1, validator=Mock(slashed=False))], + (None, NodeOperatorId(2)): [Mock(index=2, validator=Mock(slashed=False))], # stuck + (None, NodeOperatorId(3)): [Mock(index=3, validator=Mock(slashed=False))], + (None, NodeOperatorId(4)): [Mock(index=4, validator=Mock(slashed=False))], # stuck + (None, NodeOperatorId(5)): [ + Mock(index=5, validator=Mock(slashed=False)), + Mock(index=6, validator=Mock(slashed=False)), + ], + (None, NodeOperatorId(6)): [ + Mock(index=7, validator=Mock(slashed=False)), + Mock(index=8, validator=Mock(slashed=False)), + ], + (None, NodeOperatorId(7)): [Mock(index=9, validator=Mock(slashed=False))], + (None, NodeOperatorId(8)): [ + Mock(index=10, validator=Mock(slashed=False)), + Mock(index=11, validator=Mock(slashed=True)), + ], + (None, NodeOperatorId(9)): [Mock(index=12, validator=Mock(slashed=True))], + } + ) + + module.stuck_operators = Mock( + side_effect=[ + [ + NodeOperatorId(2), + NodeOperatorId(4), + ], + [ + NodeOperatorId(2), + NodeOperatorId(4), + ], + ] + ) + + module.state = State() + l_epoch, r_epoch = EpochNumber(0), EpochNumber(1999) + frame_0 = (0, 999) + frame_1 = (1000, 1999) + module.state.init_or_migrate(l_epoch, r_epoch, epochs_per_frame=1000, consensus_version=1) + module.state = State( + { + frame_0: { + ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame + ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), + ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming + ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming + ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming + # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state + ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), + }, + frame_1: { + ValidatorIndex(0): AttestationsAccumulator(included=200, assigned=200), # short on frame + ValidatorIndex(1): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(2): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(3): AttestationsAccumulator(included=999, assigned=1000), + ValidatorIndex(4): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(5): AttestationsAccumulator(included=500, assigned=1000), # underperforming + ValidatorIndex(6): AttestationsAccumulator(included=0, assigned=0), # underperforming + ValidatorIndex(7): AttestationsAccumulator(included=900, assigned=1000), + ValidatorIndex(8): AttestationsAccumulator(included=500, assigned=1000), # underperforming + # ValidatorIndex(9): AttestationsAggregate(included=0, assigned=0), # missing in state + ValidatorIndex(10): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(11): AttestationsAccumulator(included=1000, assigned=1000), + ValidatorIndex(12): AttestationsAccumulator(included=1000, assigned=1000), + }, + } + ) + module.w3.cc = Mock() + + module.converter = Mock( + side_effect=lambda _: Mock( + frame_config=FrameConfigFactory.build(epochs_per_frame=1000), + get_epoch_last_slot=lambda epoch: epoch * 32 + 31, + ) + ) + + module._get_ref_blockstamp_for_frame = Mock( + side_effect=[ + ReferenceBlockStampFactory.build( + slot_number=frame_0[1] * 32, ref_epoch=frame_0[1], ref_slot=frame_0[1] * 32 + ), + ReferenceBlockStampFactory.build(slot_number=r_epoch * 32, ref_epoch=r_epoch, ref_slot=r_epoch * 32), + ] + ) + + blockstamp = ReferenceBlockStampFactory.build(slot_number=r_epoch * 32, ref_epoch=r_epoch, ref_slot=r_epoch * 32) + distributed, shares, logs = module.calculate_distribution(blockstamp=blockstamp) + + assert distributed == 2 * 9_998 # because of the rounding + + assert tuple(shares.items()) == ( + (NodeOperatorId(0), 952), + (NodeOperatorId(1), 4761), + (NodeOperatorId(3), 4761), + (NodeOperatorId(6), 4761), + (NodeOperatorId(8), 4761), + ) + + assert len(logs) == 2 + + for log in logs: + + assert log.frame in module.state.data.keys() + assert log.threshold == module.state.get_network_aggr(log.frame).perf - 0.05 + + assert tuple(log.operators.keys()) == ( + NodeOperatorId(0), + NodeOperatorId(1), + NodeOperatorId(2), + NodeOperatorId(3), + NodeOperatorId(4), + NodeOperatorId(5), + NodeOperatorId(6), + # NodeOperatorId(7), # Missing in state + NodeOperatorId(8), + NodeOperatorId(9), + ) + + assert not log.operators[NodeOperatorId(1)].stuck + + assert log.operators[NodeOperatorId(2)].validators == {} + assert log.operators[NodeOperatorId(2)].stuck + assert log.operators[NodeOperatorId(4)].validators == {} + assert log.operators[NodeOperatorId(4)].stuck + + assert 5 in log.operators[NodeOperatorId(5)].validators + assert 6 in log.operators[NodeOperatorId(5)].validators + assert 7 in log.operators[NodeOperatorId(6)].validators + + assert log.operators[NodeOperatorId(0)].distributed == 476 + assert log.operators[NodeOperatorId(1)].distributed in [2380, 2381] + assert log.operators[NodeOperatorId(2)].distributed == 0 + assert log.operators[NodeOperatorId(3)].distributed in [2380, 2381] + assert log.operators[NodeOperatorId(6)].distributed in [2380, 2381] # Static functions you were dreaming of for so long. diff --git a/tests/modules/csm/test_log.py b/tests/modules/csm/test_log.py index de52ca9ef..61004e9ed 100644 --- a/tests/modules/csm/test_log.py +++ b/tests/modules/csm/test_log.py @@ -1,8 +1,7 @@ import json import pytest -from src.modules.csm.log import FramePerfLog -from src.modules.csm.state import AttestationsAccumulator +from src.modules.csm.log import FramePerfLog, AttestationsAccumulator from src.types import EpochNumber, NodeOperatorId, ReferenceBlockStamp from tests.factory.blockstamp import ReferenceBlockStampFactory @@ -33,16 +32,56 @@ def test_log_encode(log: FramePerfLog): log.operators[NodeOperatorId(42)].distributed = 17 log.operators[NodeOperatorId(0)].distributed = 0 - encoded = log.encode() + logs = [log] + + encoded = FramePerfLog.encode(logs) + + for decoded in json.loads(encoded): + assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["assigned"] == 220 + assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["included"] == 119 + assert decoded["operators"]["42"]["distributed"] == 17 + assert decoded["operators"]["0"]["distributed"] == 0 + + assert decoded["blockstamp"]["block_hash"] == log.blockstamp.block_hash + assert decoded["blockstamp"]["ref_slot"] == log.blockstamp.ref_slot + + assert decoded["threshold"] == log.threshold + assert decoded["frame"] == list(log.frame) + + +def test_logs_encode(): + log_0 = FramePerfLog(ReferenceBlockStampFactory.build(), (EpochNumber(100), EpochNumber(500))) + log_0.operators[NodeOperatorId(42)].validators["41337"].perf = AttestationsAccumulator(220, 119) + log_0.operators[NodeOperatorId(42)].distributed = 17 + log_0.operators[NodeOperatorId(0)].distributed = 0 + + log_1 = FramePerfLog(ReferenceBlockStampFactory.build(), (EpochNumber(500), EpochNumber(900))) + log_1.operators[NodeOperatorId(5)].validators["1234"].perf = AttestationsAccumulator(400, 399) + log_1.operators[NodeOperatorId(5)].distributed = 40 + log_1.operators[NodeOperatorId(18)].distributed = 0 + + logs = [log_0, log_1] + + encoded = FramePerfLog.encode(logs) + decoded = json.loads(encoded) - assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["assigned"] == 220 - assert decoded["operators"]["42"]["validators"]["41337"]["perf"]["included"] == 119 - assert decoded["operators"]["42"]["distributed"] == 17 - assert decoded["operators"]["0"]["distributed"] == 0 + assert len(decoded) == 2 + + assert decoded[0]["operators"]["42"]["validators"]["41337"]["perf"]["assigned"] == 220 + assert decoded[0]["operators"]["42"]["validators"]["41337"]["perf"]["included"] == 119 + assert decoded[0]["operators"]["42"]["distributed"] == 17 + assert decoded[0]["operators"]["0"]["distributed"] == 0 + + assert decoded[1]["operators"]["5"]["validators"]["1234"]["perf"]["assigned"] == 400 + assert decoded[1]["operators"]["5"]["validators"]["1234"]["perf"]["included"] == 399 + assert decoded[1]["operators"]["5"]["distributed"] == 40 + assert decoded[1]["operators"]["18"]["distributed"] == 0 - assert decoded["blockstamp"]["block_hash"] == log.blockstamp.block_hash - assert decoded["blockstamp"]["ref_slot"] == log.blockstamp.ref_slot + for i, log in enumerate(logs): + assert decoded[i]["blockstamp"]["block_hash"] == log.blockstamp.block_hash + assert decoded[i]["blockstamp"]["ref_slot"] == log.blockstamp.ref_slot - assert decoded["threshold"] == log.threshold - assert decoded["frame"] == list(log.frame) + assert decoded[i]["threshold"] == log.threshold + assert decoded[i]["frame"] == list(log.frame) + assert decoded[i]["distributable"] == log.distributable diff --git a/tests/modules/csm/test_state.py b/tests/modules/csm/test_state.py index 7539f7d26..d781522e2 100644 --- a/tests/modules/csm/test_state.py +++ b/tests/modules/csm/test_state.py @@ -26,51 +26,43 @@ def test_attestation_aggregate_perf(): def test_state_avg_perf(): state = State() - assert state.get_network_aggr().perf == 0 + frame = (0, 999) - state = State( - { + with pytest.raises(ValueError): + state.get_network_aggr(frame) + + state = State() + state.init_or_migrate(*frame, 1000, 1) + state.data = { + frame: { ValidatorIndex(0): AttestationsAccumulator(included=0, assigned=0), ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=0), } - ) + } - assert state.get_network_aggr().perf == 0 + assert state.get_network_aggr(frame).perf == 0 - state = State( - { + state.data = { + frame: { ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), } - ) - - assert state.get_network_aggr().perf == 0.5 - + } -def test_state_frame(): - state = State() - - state.migrate(EpochNumber(100), EpochNumber(500), 1) - assert state.frame == (100, 500) - - state.migrate(EpochNumber(300), EpochNumber(301), 1) - assert state.frame == (300, 301) - - state.clear() - - with pytest.raises(ValueError, match="Epochs to process are not set"): - state.frame + assert state.get_network_aggr(frame).perf == 0.5 def test_state_attestations(): state = State( { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + (0, 999): { + ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), + ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + } } ) - network_aggr = state.get_network_aggr() + network_aggr = state.get_network_aggr((0, 999)) assert network_aggr.assigned == 1000 assert network_aggr.included == 500 @@ -79,8 +71,10 @@ def test_state_attestations(): def test_state_load(): orig = State( { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + (0, 999): { + ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), + ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + } } ) @@ -92,8 +86,10 @@ def test_state_load(): def test_state_clear(): state = State( { - ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), - ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + (0, 999): { + ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), + ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + } } ) @@ -113,27 +109,42 @@ def test_state_add_processed_epoch(): def test_state_inc(): + + frame_0 = (0, 999) + frame_1 = (1000, 1999) + state = State( { - ValidatorIndex(0): AttestationsAccumulator(included=0, assigned=0), - ValidatorIndex(1): AttestationsAccumulator(included=1, assigned=2), + frame_0: { + ValidatorIndex(0): AttestationsAccumulator(included=333, assigned=777), + ValidatorIndex(1): AttestationsAccumulator(included=167, assigned=223), + }, + frame_1: { + ValidatorIndex(0): AttestationsAccumulator(included=1, assigned=1), + ValidatorIndex(1): AttestationsAccumulator(included=0, assigned=1), + }, } ) - state.inc(ValidatorIndex(0), True) - state.inc(ValidatorIndex(0), False) + state.increment_duty(999, ValidatorIndex(0), True) + state.increment_duty(999, ValidatorIndex(0), False) + state.increment_duty(999, ValidatorIndex(1), True) + state.increment_duty(999, ValidatorIndex(1), True) + state.increment_duty(999, ValidatorIndex(1), False) + state.increment_duty(999, ValidatorIndex(2), True) - state.inc(ValidatorIndex(1), True) - state.inc(ValidatorIndex(1), True) - state.inc(ValidatorIndex(1), False) + state.increment_duty(1000, ValidatorIndex(2), False) - state.inc(ValidatorIndex(2), True) - state.inc(ValidatorIndex(2), False) + assert tuple(state.data[frame_0].values()) == ( + AttestationsAccumulator(included=334, assigned=779), + AttestationsAccumulator(included=169, assigned=226), + AttestationsAccumulator(included=1, assigned=1), + ) - assert tuple(state.data.values()) == ( - AttestationsAccumulator(included=1, assigned=2), - AttestationsAccumulator(included=3, assigned=5), - AttestationsAccumulator(included=1, assigned=2), + assert tuple(state.data[frame_1].values()) == ( + AttestationsAccumulator(included=1, assigned=1), + AttestationsAccumulator(included=0, assigned=1), + AttestationsAccumulator(included=0, assigned=1), ) @@ -155,7 +166,7 @@ def test_empty_to_new_frame(self): l_epoch = EpochNumber(1) r_epoch = EpochNumber(255) - state.migrate(l_epoch, r_epoch, 1) + state.init_or_migrate(l_epoch, r_epoch, 255, 1) assert not state.is_empty assert state.unprocessed_epochs == set(sequence(l_epoch, r_epoch)) @@ -171,32 +182,60 @@ def test_empty_to_new_frame(self): def test_new_frame_requires_discarding_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new): state = State() state.clear = Mock(side_effect=state.clear) - state.migrate(l_epoch_old, r_epoch_old, 1) + state.init_or_migrate(l_epoch_old, r_epoch_old, r_epoch_old - l_epoch_old + 1, 1) state.clear.assert_not_called() - state.migrate(l_epoch_new, r_epoch_new, 1) + state.init_or_migrate(l_epoch_new, r_epoch_new, r_epoch_new - l_epoch_new + 1, 1) state.clear.assert_called_once() assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) @pytest.mark.parametrize( - ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new"), + ("l_epoch_old", "r_epoch_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame"), + [ + pytest.param(1, 255, 1, 510, 255, id="Migrate Aa..b..B"), + ], + ) + def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new, epochs_per_frame): + state = State() + state.clear = Mock(side_effect=state.clear) + + state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame, 1) + state.clear.assert_not_called() + + state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame, 1) + state.clear.assert_not_called() + + assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) + assert len(state.data) == 2 + assert list(state.data.keys()) == [(l_epoch_old, r_epoch_old), (r_epoch_old + 1, r_epoch_new)] + assert state.calculate_frames(state._epochs_to_process, epochs_per_frame) == [ + (l_epoch_old, r_epoch_old), + (r_epoch_old + 1, r_epoch_new), + ] + + @pytest.mark.parametrize( + ("l_epoch_old", "r_epoch_old", "epochs_per_frame_old", "l_epoch_new", "r_epoch_new", "epochs_per_frame_new"), [ - pytest.param(1, 255, 1, 510, id="Migrate Aa..b..B"), - pytest.param(32, 510, 1, 510, id="Migrate: A..a..b..B"), + pytest.param(32, 510, 479, 1, 510, 510, id="Migrate: A..a..b..B"), ], ) - def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new, r_epoch_new): + def test_new_frame_extends_old_state_with_single_frame( + self, l_epoch_old, r_epoch_old, epochs_per_frame_old, l_epoch_new, r_epoch_new, epochs_per_frame_new + ): state = State() state.clear = Mock(side_effect=state.clear) - state.migrate(l_epoch_old, r_epoch_old, 1) + state.init_or_migrate(l_epoch_old, r_epoch_old, epochs_per_frame_old, 1) state.clear.assert_not_called() - state.migrate(l_epoch_new, r_epoch_new, 1) + state.init_or_migrate(l_epoch_new, r_epoch_new, epochs_per_frame_new, 1) state.clear.assert_not_called() assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new)) + assert len(state.data) == 1 + assert list(state.data.keys())[0] == (l_epoch_new, r_epoch_new) + assert state.calculate_frames(state._epochs_to_process, epochs_per_frame_new) == [(l_epoch_new, r_epoch_new)] @pytest.mark.parametrize( ("old_version", "new_version"), @@ -212,8 +251,8 @@ def test_consensus_version_change(self, old_version, new_version): l_epoch = r_epoch = EpochNumber(255) - state.migrate(l_epoch, r_epoch, old_version) + state.init_or_migrate(l_epoch, r_epoch, 1, old_version) state.clear.assert_not_called() - state.migrate(l_epoch, r_epoch, new_version) + state.init_or_migrate(l_epoch, r_epoch, 1, new_version) state.clear.assert_called_once()