Skip to content

Commit

Permalink
fix: linter
Browse files Browse the repository at this point in the history
  • Loading branch information
vgorkavenko committed Feb 3, 2025
1 parent eead54b commit 684a40f
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/modules/checks/suites/consensus_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ def check_attestation_committees(web3: Web3, blockstamp):
assert web3.cc.get_attestation_committees(blockstamp, epoch), "consensus-client provide no attestation committees"


def check_block_attestations(web3: Web3, blockstamp):
def check_block_attestations_and_sync(web3: Web3, blockstamp):
"""Check that consensus-client able to provide block attestations"""
assert web3.cc.get_block_attestations(blockstamp.slot_number), "consensus-client provide no block attestations"
assert web3.cc.get_block_attestations_and_sync(blockstamp.slot_number), "consensus-client provide no block attestations and sync"
5 changes: 2 additions & 3 deletions src/modules/csm/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def _is_min_step_reached(self):
return False


type CommitteeIndex = str
type SlotBlockRoot = tuple[SlotNumber, BlockRoot | None]
type SyncCommittees = dict[SlotNumber, list[ValidatorDuty]]
type AttestationCommittees = dict[tuple[SlotNumber, CommitteeIndex], list[ValidatorDuty]]
Expand Down Expand Up @@ -250,7 +249,7 @@ def _check_duties(
missed_slot = root is None
if missed_slot:
continue
attestations, sync_aggregate = self.cc.get_block_attestations_and_sync(BlockRoot(root))
attestations, sync_aggregate = self.cc.get_block_attestations_and_sync(root)
process_attestations(attestations, att_committees, self.eip7549_supported)
if (slot, root) in duty_epoch_roots:
propose_duties[slot].included = True
Expand Down Expand Up @@ -384,7 +383,7 @@ def _get_dependent_root_for_proposer_duties(
}
)
break
dependent_slot -= 1
dependent_slot = SlotNumber(int(dependent_slot - 1))
except SlotOutOfRootsRange:
dependent_non_missed_slot = SlotNumber(int(
get_prev_non_missed_slot(
Expand Down
3 changes: 1 addition & 2 deletions src/modules/csm/csm.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ def calculate_distribution(
shares = defaultdict[NodeOperatorId, int](int)
logs: list[FramePerfLog] = []

frames = self.state.calculate_frames(self.state._epochs_to_process, self.state._epochs_per_frame)
for frame in frames:
for frame in self.state.frames:
from_epoch, to_epoch = frame
logger.info({"msg": f"Calculating distribution for frame [{from_epoch};{to_epoch}]"})
frame_blockstamp = blockstamp
Expand Down
38 changes: 20 additions & 18 deletions src/modules/csm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ def add_duty(self, included: bool) -> None:
self.included += 1 if included else 0


def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
"""Split epochs to process into frames of `epochs_per_frame` length"""
frames = []
for frame_epochs in batched(epochs_to_process, epochs_per_frame):
if len(frame_epochs) < epochs_per_frame:
raise ValueError("Insufficient epochs to form a frame")
frames.append((frame_epochs[0], frame_epochs[-1]))
return frames


class State:
"""
Processing state of a CSM performance oracle frame.
Expand Down Expand Up @@ -129,11 +139,10 @@ def clear(self) -> None:
assert self.is_empty

def find_frame(self, epoch: EpochNumber) -> Frame:
frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
for epoch_range in frames:
for epoch_range in self.frames:
if epoch_range[0] <= epoch <= epoch_range[1]:
return epoch_range
raise ValueError(f"Epoch {epoch} is out of frames range: {frames}")
raise ValueError(f"Epoch {epoch} is out of frames range: {self.frames}")

def increment_att_duty(self, frame: Frame, val_index: ValidatorIndex, included: bool) -> None:
self.att_data[frame][val_index].add_duty(included)
Expand Down Expand Up @@ -178,16 +187,15 @@ def init_or_migrate(
self.commit()

def _fill_frames(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> None:
frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
frames = calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
for frame in frames:
self.att_data.setdefault(frame, defaultdict(DutyAccumulator))
self.prop_data.setdefault(frame, defaultdict(DutyAccumulator))
self.sync_data.setdefault(frame, defaultdict(DutyAccumulator))

def _migrate_or_invalidate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epochs_per_frame: int) -> bool:
current_frames = self.calculate_frames(self._epochs_to_process, self._epochs_per_frame)
new_frames = self.calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
inv_msg = f"Discarding invalid state cache because of frames change. {current_frames=}, {new_frames=}"
new_frames = calculate_frames(tuple(sequence(l_epoch, r_epoch)), epochs_per_frame)
inv_msg = f"Discarding invalid state cache because of frames change. {self.frames=}, {new_frames=}"

if self._invalidate_on_epoch_range_change(l_epoch, r_epoch):
logger.warning({"msg": inv_msg})
Expand All @@ -196,10 +204,10 @@ def _migrate_or_invalidate(self, l_epoch: EpochNumber, r_epoch: EpochNumber, epo
frame_expanded = epochs_per_frame > self._epochs_per_frame
frame_shrunk = epochs_per_frame < self._epochs_per_frame

has_single_frame = len(current_frames) == len(new_frames) == 1
has_single_frame = len(self.frames) == len(new_frames) == 1

if has_single_frame and frame_expanded:
current_frame, *_ = current_frames
current_frame, *_ = self.frames
new_frame, *_ = new_frames
self.att_data[new_frame] = self.att_data.pop(current_frame)
self.prop_data[new_frame] = self.prop_data.pop(current_frame)
Expand Down Expand Up @@ -236,15 +244,9 @@ def validate(self, l_epoch: EpochNumber, r_epoch: EpochNumber) -> None:
if epoch not in self._processed_epochs:
raise InvalidState(f"Epoch {epoch} missing in processed epochs")

@staticmethod
def calculate_frames(epochs_to_process: tuple[EpochNumber, ...], epochs_per_frame: int) -> list[Frame]:
"""Split epochs to process into frames of `epochs_per_frame` length"""
frames = []
for frame_epochs in batched(epochs_to_process, epochs_per_frame):
if len(frame_epochs) < epochs_per_frame:
raise ValueError("Insufficient epochs to form a frame")
frames.append((frame_epochs[0], frame_epochs[-1]))
return frames
@property
def frames(self) -> list[Frame]:
return calculate_frames(self._epochs_to_process, self._epochs_per_frame)

def get_att_network_aggr(self, frame: Frame) -> DutyAccumulator:
# TODO: exclude `active_slashed` validators from the calculation
Expand Down
2 changes: 1 addition & 1 deletion src/providers/consensus/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def get_block_attestations_and_sync(self, state_id: SlotNumber | BlockRoot) -> t
attestations = [BlockAttestationResponse.from_response(**att) for att in data["message"]["body"]["attestations"]]
sync = SyncAggregate.from_response(**data["message"]["body"]["sync_aggregate"])

return attestations, sync
return cast(list[BlockAttestation], attestations), sync

@list_of_dataclasses(SlotAttestationCommittee.from_response)
def get_attestation_committees(
Expand Down
6 changes: 3 additions & 3 deletions tests/modules/csm/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from src.modules.csm.state import DutyAccumulator, State
from src.modules.csm.state import DutyAccumulator, State, calculate_frames
from src.types import EpochNumber, ValidatorIndex
from src.utils.range import sequence

Expand Down Expand Up @@ -209,7 +209,7 @@ def test_new_frame_extends_old_state(self, l_epoch_old, r_epoch_old, l_epoch_new
assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new))
assert len(state.att_data) == 2
assert list(state.att_data.keys()) == [(l_epoch_old, r_epoch_old), (r_epoch_old + 1, r_epoch_new)]
assert state.calculate_frames(state._epochs_to_process, epochs_per_frame) == [
assert calculate_frames(state._epochs_to_process, epochs_per_frame) == [
(l_epoch_old, r_epoch_old),
(r_epoch_old + 1, r_epoch_new),
]
Expand All @@ -235,7 +235,7 @@ def test_new_frame_extends_old_state_with_single_frame(
assert state.unprocessed_epochs == set(sequence(l_epoch_new, r_epoch_new))
assert len(state.att_data) == 1
assert list(state.att_data.keys())[0] == (l_epoch_new, r_epoch_new)
assert state.calculate_frames(state._epochs_to_process, epochs_per_frame_new) == [(l_epoch_new, r_epoch_new)]
assert calculate_frames(state._epochs_to_process, epochs_per_frame_new) == [(l_epoch_new, r_epoch_new)]

@pytest.mark.parametrize(
("old_version", "new_version"),
Expand Down

0 comments on commit 684a40f

Please sign in to comment.