Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Remove extra data v1 #602

Merged
merged 1 commit into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions src/modules/accounting/accounting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from src import variables
from src.constants import SHARE_RATE_PRECISION_E27
from src.modules.accounting.third_phase.extra_data import ExtraDataService
from src.modules.accounting.third_phase.extra_data_v2 import ExtraDataServiceV2
from src.modules.accounting.third_phase.types import ExtraData, FormatList
from src.modules.accounting.types import (
ReportData,
Expand Down Expand Up @@ -333,24 +332,14 @@ def _is_bunker(self, blockstamp: ReferenceBlockStamp) -> BunkerMode:

@lru_cache(maxsize=1)
def get_extra_data(self, blockstamp: ReferenceBlockStamp) -> ExtraData:
consensus_version = self.w3.lido_contracts.accounting_oracle.get_consensus_version(blockstamp.block_hash)

chain_config = self.get_chain_config(blockstamp)
stuck_validators = self.lido_validator_state_service.get_lido_newly_stuck_validators(blockstamp, chain_config)
logger.info({'msg': 'Calculate stuck validators.', 'value': stuck_validators})
exited_validators = self.lido_validator_state_service.get_lido_newly_exited_validators(blockstamp)
logger.info({'msg': 'Calculate exited validators.', 'value': exited_validators})
orl = self.w3.lido_contracts.oracle_report_sanity_checker.get_oracle_report_limits(blockstamp.block_hash)

if consensus_version == 1:
return ExtraDataService.collect(
stuck_validators,
exited_validators,
orl.max_items_per_extra_data_transaction,
orl.max_node_operators_per_extra_data_item,
)

return ExtraDataServiceV2.collect(
return ExtraDataService.collect(
stuck_validators,
exited_validators,
orl.max_items_per_extra_data_transaction,
Expand Down Expand Up @@ -383,7 +372,7 @@ def _calculate_wq_report(self, blockstamp: ReferenceBlockStamp) -> WqReport:

def _calculate_extra_data_report(self, blockstamp: ReferenceBlockStamp) -> ExtraData:
stuck_validators, exited_validators, orl = self._get_generic_extra_data(blockstamp)
return ExtraDataServiceV2.collect(
return ExtraDataService.collect(
stuck_validators,
exited_validators,
orl.max_items_per_extra_data_transaction,
Expand Down
159 changes: 82 additions & 77 deletions src/modules/accounting/third_phase/extra_data.py
Original file line number Diff line number Diff line change
@@ -1,133 +1,138 @@
import itertools
from dataclasses import dataclass
from itertools import groupby, batched
from typing import Sequence

from hexbytes import HexBytes

from src.modules.accounting.third_phase.types import ItemType, ExtraData, FormatList, ExtraDataLengths
from src.modules.accounting.third_phase.types import ExtraData, ItemType, ExtraDataLengths, FormatList
from src.modules.submodules.types import ZERO_HASH
from src.types import NodeOperatorGlobalIndex
from src.web3py.types import Web3


@dataclass
class ItemPayload:
module_id: bytes
node_ops_count: bytes
node_operator_ids: bytes
vals_counts: bytes


@dataclass
class ExtraDataItem:
item_index: bytes
item_type: ItemType
item_payload: ItemPayload
module_id: int
node_operator_ids: Sequence[int]
vals_counts: Sequence[int]


class ExtraDataService:
"""
Service that encodes extra data into bytes in correct order.
Extra data is an array of items, each item being encoded as follows:
| 3 bytes | 2 bytes | X bytes |
| itemIndex | itemType | itemPayload |
| 32 bytes | 3 bytes | 2 bytes | X bytes |
| nextHash | itemIndex | itemType | itemPayload |
itemPayload format:
| 3 bytes | 8 bytes | nodeOpsCount * 8 bytes | nodeOpsCount * 16 bytes |
| moduleId | nodeOpsCount | nodeOperatorIds | stuckOrExitedValsCount |
max_items_count - max itemIndex in extra data.
max_items_count_per_tx - max itemIndex in extra data.
max_no_in_payload_count - max nodeOpsCount that could be used in itemPayload.
"""
@classmethod
def collect(
cls,
stuck_validators: dict[NodeOperatorGlobalIndex, int],
exited_validators: dict[NodeOperatorGlobalIndex, int],
max_items_count: int,
max_items_count_per_tx: int,
max_no_in_payload_count: int,
) -> ExtraData:
stuck_payloads = cls.build_validators_payloads(stuck_validators, max_no_in_payload_count)
exited_payloads = cls.build_validators_payloads(exited_validators, max_no_in_payload_count)
items_count, txs = cls.build_extra_transactions_data(stuck_payloads, exited_payloads, max_items_count_per_tx)
first_hash, hashed_txs = cls.add_hashes_to_transactions(txs)

extra_data = cls.build_extra_data(stuck_payloads, exited_payloads, max_items_count)
extra_data_bytes = cls.to_bytes(extra_data)

if extra_data:
extra_data_list = [extra_data_bytes]
data_format = FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY
data_hash = Web3.keccak(extra_data_bytes)
if items_count:
extra_data_format = FormatList.EXTRA_DATA_FORMAT_LIST_NON_EMPTY
else:
extra_data_list = []
data_format = FormatList.EXTRA_DATA_FORMAT_LIST_EMPTY
data_hash = HexBytes(ZERO_HASH)
extra_data_format = FormatList.EXTRA_DATA_FORMAT_LIST_EMPTY

return ExtraData(
extra_data_list=extra_data_list,
data_hash=data_hash,
format=data_format.value,
items_count=len(extra_data),
items_count=items_count,
extra_data_list=hashed_txs,
data_hash=first_hash,
format=extra_data_format.value,
)

@staticmethod
@classmethod
def build_validators_payloads(
cls,
validators: dict[NodeOperatorGlobalIndex, int],
max_no_in_payload_count: int,
) -> list[ItemPayload]:
# sort by module id and node operator id
operator_validators = sorted(validators.items(), key=lambda x: x[0])

payloads = []

for module_id, operators_by_module in itertools.groupby(operator_validators, key=lambda x: x[0][0]):
operator_ids = []
vals_count = []

for ((_, no_id), validators_count) in list(operators_by_module)[:max_no_in_payload_count]:
operator_ids.append(no_id.to_bytes(ExtraDataLengths.NODE_OPERATOR_ID))
vals_count.append(validators_count.to_bytes(ExtraDataLengths.STUCK_OR_EXITED_VALS_COUNT))

payloads.append(
ItemPayload(
module_id=module_id.to_bytes(ExtraDataLengths.MODULE_ID),
node_ops_count=len(operator_ids).to_bytes(ExtraDataLengths.NODE_OPS_COUNT),
node_operator_ids=b"".join(operator_ids),
vals_counts=b"".join(vals_count),
for module_id, operators_by_module in groupby(operator_validators, key=lambda x: x[0][0]):
for nos_in_batch in batched(list(operators_by_module), max_no_in_payload_count):
operator_ids = []
vals_count = []

for ((_, no_id), validators_count) in nos_in_batch:
operator_ids.append(no_id)
vals_count.append(validators_count)

payloads.append(
ItemPayload(
module_id=module_id,
node_operator_ids=operator_ids,
vals_counts=vals_count,
)
)
)

return payloads

@staticmethod
def build_extra_data(stuck_payloads: list[ItemPayload], exited_payloads: list[ItemPayload], max_items_count: int):
@classmethod
def build_extra_transactions_data(
cls,
stuck_payloads: list[ItemPayload],
exited_payloads: list[ItemPayload],
max_items_count_per_tx: int,
) -> tuple[int, list[bytes]]:
all_payloads = [
*[(ItemType.EXTRA_DATA_TYPE_STUCK_VALIDATORS, payload) for payload in stuck_payloads],
*[(ItemType.EXTRA_DATA_TYPE_EXITED_VALIDATORS, payload) for payload in exited_payloads],
]

index = 0
extra_data = []

for item_type, payloads in [
(ItemType.EXTRA_DATA_TYPE_STUCK_VALIDATORS, stuck_payloads),
(ItemType.EXTRA_DATA_TYPE_EXITED_VALIDATORS, exited_payloads),
]:
for payload in payloads:
extra_data.append(ExtraDataItem(
item_index=index.to_bytes(ExtraDataLengths.ITEM_INDEX),
item_type=item_type,
item_payload=payload
))
result = []

for payload_batch in batched(all_payloads, max_items_count_per_tx):
tx_body = b''
for item_type, payload in payload_batch:
tx_body += index.to_bytes(ExtraDataLengths.ITEM_INDEX)
tx_body += item_type.value.to_bytes(ExtraDataLengths.ITEM_TYPE)
tx_body += payload.module_id.to_bytes(ExtraDataLengths.MODULE_ID)
tx_body += len(payload.node_operator_ids).to_bytes(ExtraDataLengths.NODE_OPS_COUNT)
tx_body += b''.join(
no_id.to_bytes(ExtraDataLengths.NODE_OPERATOR_ID)
for no_id in payload.node_operator_ids
)
tx_body += b''.join(
count.to_bytes(ExtraDataLengths.STUCK_OR_EXITED_VALS_COUNT)
for count in payload.vals_counts
)

index += 1
if index == max_items_count:
return extra_data

return extra_data
result.append(tx_body)

return index, result

@staticmethod
def to_bytes(extra_data: list[ExtraDataItem]) -> bytes:
extra_data_bytes = b''
for item in extra_data:
extra_data_bytes += item.item_index
extra_data_bytes += item.item_type.value.to_bytes(ExtraDataLengths.ITEM_TYPE)
extra_data_bytes += item.item_payload.module_id
extra_data_bytes += item.item_payload.node_ops_count
extra_data_bytes += item.item_payload.node_operator_ids
extra_data_bytes += item.item_payload.vals_counts
return extra_data_bytes
def add_hashes_to_transactions(txs_data: list[bytes]) -> tuple[bytes, list[bytes]]:
txs_data.reverse()

txs_with_hashes = []
next_hash = ZERO_HASH

for tx in txs_data:
full_tx_data = next_hash + tx
txs_with_hashes.append(full_tx_data)
next_hash = Web3.keccak(full_tx_data)

txs_with_hashes.reverse()

return next_hash, txs_with_hashes
138 changes: 0 additions & 138 deletions src/modules/accounting/third_phase/extra_data_v2.py

This file was deleted.

Loading
Loading