From 87672acfe75b3b230d5a58ab6de9c5a8c2ba1031 Mon Sep 17 00:00:00 2001 From: El De-dog-lo <3859395+fubuloubu@users.noreply.github.com> Date: Thu, 11 Apr 2024 14:14:17 -0400 Subject: [PATCH] refactor: allow multiple handlers (#66) * refactor: add TaskType enum to types * refactor!: use task collection by task type; use dynamic registration * refactor: use new task collections api for fetch task handlers * refactor!: use task type to differentiate tasks in middleware, not name * fix: typing bugs, not using Union for 3.8 support * fix: use task type to capture event logs * refactor: use defaultdict instead of custom collection type * refactor: use standardized labels, use task_name for task_id * refactor: remove `.task_name` from message labels * refactor: convert to TaskType for better processing * refactor: use StrEnum if available * docs: add note to deprecate in breaking change * refactor: make object type clearer when working with labels in middleware * refactor: use official backport * docs: update typing and add docs for dynamic broker task decorator fn * style: ignore mypy typing issues on <3.11 * refactor: avoid div/0 fault, fix duplicate log entry for results w/errs * refactor: rollback weird typing backport issue --- silverback/application.py | 127 +++++++++++++++++++------------------- silverback/exceptions.py | 12 ++-- silverback/middlewares.py | 85 ++++++++++++------------- silverback/runner.py | 26 ++++---- silverback/types.py | 24 ++++--- 5 files changed, 132 insertions(+), 142 deletions(-) diff --git a/silverback/application.py b/silverback/application.py index 75717c91..b04eea34 100644 --- a/silverback/application.py +++ b/silverback/application.py @@ -1,4 +1,6 @@ import atexit +from collections import defaultdict +from dataclasses import dataclass from datetime import timedelta from typing import Callable, Dict, Optional, Union @@ -6,12 +8,18 @@ from ape.contracts import ContractEvent, ContractInstance from ape.logging import logger from ape.managers.chain import BlockContainer -from ape.types import AddressType from ape.utils import ManagerAccessMixin from taskiq import AsyncTaskiqDecoratedTask, TaskiqEvents -from .exceptions import DuplicateHandlerError, InvalidContainerTypeError +from .exceptions import ContainerTypeMismatchError, InvalidContainerTypeError from .settings import Settings +from .types import TaskType + + +@dataclass +class TaskData: + container: Union[BlockContainer, ContractEvent, None] + handler: AsyncTaskiqDecoratedTask class SilverbackApp(ManagerAccessMixin): @@ -52,7 +60,8 @@ def __init__(self, settings: Optional[Settings] = None): logger.info(f"Loading Silverback App with settings:\n {settings_str}") self.broker = settings.get_broker() - self.contract_events: Dict[AddressType, Dict[str, ContractEvent]] = {} + # NOTE: If no tasks registered yet, defaults to empty list instead of raising KeyError + self.tasks: defaultdict[TaskType, list[TaskData]] = defaultdict(list) self.poll_settings: Dict[str, Dict] = {} atexit.register(self.network.__exit__, None, None, None) @@ -72,6 +81,54 @@ def __init__(self, settings: Optional[Settings] = None): f"{signer_str}{start_block_str}{new_block_timeout_str}" ) + def broker_task_decorator( + self, + task_type: TaskType, + container: Union[BlockContainer, ContractEvent, None] = None, + ) -> Callable[[Callable], AsyncTaskiqDecoratedTask]: + """ + Dynamically create a new broker task that handles tasks of ``task_type``. + + Args: + task_type: :class:`~silverback.types.TaskType`: The type of task to create. + container: (Union[BlockContainer, ContractEvent]): The event source to watch. + + Returns: + Callable[[Callable], :class:`~taskiq.AsyncTaskiqDecoratedTask`]: + A function wrapper that will register the task handler. + + Raises: + :class:`~silverback.exceptions.ContainerTypeMismatchError`: + If there is a mismatch between `task_type` and the `container` + type it should handle. + """ + if ( + (task_type is TaskType.NEW_BLOCKS and not isinstance(container, BlockContainer)) + or (task_type is TaskType.EVENT_LOG and not isinstance(container, ContractEvent)) + or ( + task_type + not in ( + TaskType.NEW_BLOCKS, + TaskType.EVENT_LOG, + ) + and container is not None + ) + ): + raise ContainerTypeMismatchError(task_type, container) + + # Register user function as task handler with our broker + def add_taskiq_task(handler: Callable) -> AsyncTaskiqDecoratedTask: + broker_task = self.broker.register_task( + handler, + task_name=handler.__name__, + task_type=str(task_type), + ) + + self.tasks[task_type].append(TaskData(container=container, handler=broker_task)) + return broker_task + + return add_taskiq_task + def on_startup(self) -> Callable: """ Code to execute on one worker upon startup / restart after an error. @@ -82,7 +139,7 @@ def on_startup(self) -> Callable: def do_something_on_startup(startup_state): ... # Reprocess missed events or blocks """ - return self.broker.task(task_name="silverback_startup") + return self.broker_task_decorator(TaskType.STARTUP) def on_shutdown(self) -> Callable: """ @@ -94,7 +151,7 @@ def on_shutdown(self) -> Callable: def do_something_on_shutdown(): ... # Record final state of app """ - return self.broker.task(task_name="silverback_shutdown") + return self.broker_task_decorator(TaskType.SHUTDOWN) def on_worker_startup(self) -> Callable: """ @@ -120,48 +177,6 @@ def do_something_on_shutdown(state): """ return self.broker.on_event(TaskiqEvents.WORKER_SHUTDOWN) - def get_startup_handler(self) -> Optional[AsyncTaskiqDecoratedTask]: - """ - Get access to the handler for `silverback_startup` events. - - Returns: - Optional[AsyncTaskiqDecoratedTask]: Returns decorated task, if one has been created. - """ - return self.broker.find_task("silverback_startup") - - def get_shutdown_handler(self) -> Optional[AsyncTaskiqDecoratedTask]: - """ - Get access to the handler for `silverback_shutdown` events. - - Returns: - Optional[AsyncTaskiqDecoratedTask]: Returns decorated task, if one has been created. - """ - return self.broker.find_task("silverback_shutdown") - - def get_block_handler(self) -> Optional[AsyncTaskiqDecoratedTask]: - """ - Get access to the handler for `block` events. - - Returns: - Optional[AsyncTaskiqDecoratedTask]: Returns decorated task, if one has been created. - """ - return self.broker.find_task("block") - - def get_event_handler( - self, event_target: AddressType, event_name: str - ) -> Optional[AsyncTaskiqDecoratedTask]: - """ - Get access to the handler for `:` events. - - Args: - event_target (AddressType): The contract address of the target. - event_name: (str): The name of the event emitted by ``event_target``. - - Returns: - Optional[AsyncTaskiqDecoratedTask]: Returns decorated task, if one has been created. - """ - return self.broker.find_task(f"{event_target}/event/{event_name}") - def on_( self, container: Union[BlockContainer, ContractEvent], @@ -183,9 +198,6 @@ def on_( If the type of `container` is not configurable for the app. """ if isinstance(container, BlockContainer): - if self.get_block_handler(): - raise DuplicateHandlerError("block") - if new_block_timeout is not None: if "_blocks_" in self.poll_settings: self.poll_settings["_blocks_"]["new_block_timeout"] = new_block_timeout @@ -198,21 +210,12 @@ def on_( else: self.poll_settings["_blocks_"] = {"start_block": start_block} - return self.broker.task(task_name="block") + return self.broker_task_decorator(TaskType.NEW_BLOCKS, container=container) elif isinstance(container, ContractEvent) and isinstance( container.contract, ContractInstance ): - if self.get_event_handler(container.contract.address, container.abi.name): - raise DuplicateHandlerError( - f"event {container.contract.address}:{container.abi.name}" - ) - key = container.contract.address - if container.contract.address in self.contract_events: - self.contract_events[key][container.abi.name] = container - else: - self.contract_events[key] = {container.abi.name: container} if new_block_timeout is not None: if key in self.poll_settings: @@ -226,9 +229,7 @@ def on_( else: self.poll_settings[key] = {"start_block": start_block} - return self.broker.task( - task_name=f"{container.contract.address}/event/{container.abi.name}" - ) + return self.broker_task_decorator(TaskType.EVENT_LOG, container=container) # TODO: Support account transaction polling # TODO: Support mempool polling diff --git a/silverback/exceptions.py b/silverback/exceptions.py index 507cea19..125e85a0 100644 --- a/silverback/exceptions.py +++ b/silverback/exceptions.py @@ -3,21 +3,23 @@ from ape.exceptions import ApeException from ape.logging import logger +from .types import TaskType + class ImportFromStringError(Exception): pass -class DuplicateHandlerError(Exception): - def __init__(self, handler_type: str): - super().__init__(f"Only one handler allowed for: {handler_type}") - - class InvalidContainerTypeError(Exception): def __init__(self, container: Any): super().__init__(f"Invalid container type: {container.__class__}") +class ContainerTypeMismatchError(Exception): + def __init__(self, task_type: TaskType, container: Any): + super().__init__(f"Invalid container type for '{task_type}': {container.__class__}") + + class NoWebsocketAvailableError(Exception): def __init__(self): super().__init__( diff --git a/silverback/middlewares.py b/silverback/middlewares.py index 4e713d73..18c6c72f 100644 --- a/silverback/middlewares.py +++ b/silverback/middlewares.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple +from typing import Any from ape.logging import logger from ape.types import ContractLog @@ -7,27 +7,10 @@ from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult from silverback.persistence import HandlerResult -from silverback.types import SilverbackID, handler_id_block, handler_id_event +from silverback.types import SilverbackID, TaskType from silverback.utils import hexbytes_dict -def resolve_task(message: TaskiqMessage) -> Tuple[str, Optional[int], Optional[int]]: - block_number = None - log_index = None - task_id = message.task_name - - if task_id == "block": - block_number = message.args[0].number - task_id = handler_id_block(block_number) - elif "event" in task_id: - block_number = message.args[0].block_number - log_index = message.args[0].log_index - # TODO: Should standardize on event signature here instead of name in case of overloading - task_id = handler_id_event(message.args[0].contract_address, message.args[0].event_name) - - return task_id, block_number, log_index - - class SilverbackMiddleware(TaskiqMiddleware, ManagerAccessMixin): def __init__(self, *args, **kwargs): def compute_block_time() -> int: @@ -66,46 +49,62 @@ def fix_dict(data: dict, recurse_count: int = 0) -> dict: return message def _create_label(self, message: TaskiqMessage) -> str: - if message.task_name == "block": - args = f"[block={message.args[0].hash.hex()}]" - - elif "event" in message.task_name: - args = f"[txn={message.args[0].transaction_hash},log_index={message.args[0].log_index}]" + if labels_str := ",".join(f"{k}={v}" for k, v in message.labels.items()): + return f"{message.task_name}[{labels_str}]" else: - args = "" - - return f"{message.task_name}{args}" + return message.task_name def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage: - if message.task_name == "block": + if not (task_type := message.labels.pop("task_type")): + return message # Not a silverback task + + try: + task_type = TaskType(task_type) + except ValueError: + return message # Not a silverback task + + # Add extra labels for our task to see what their source was + if task_type is TaskType.NEW_BLOCKS: # NOTE: Necessary because we don't know the exact block class - message.args[0] = self.provider.network.ecosystem.decode_block( + block = message.args[0] = self.provider.network.ecosystem.decode_block( hexbytes_dict(message.args[0]) ) + message.labels["block_number"] = str(block.number) + message.labels["block_hash"] = block.hash.hex() - elif "event" in message.task_name: + elif task_type is TaskType.EVENT_LOG: # NOTE: Just in case the user doesn't specify type as `ContractLog` - message.args[0] = ContractLog.model_validate(message.args[0]) + log = message.args[0] = ContractLog.model_validate(message.args[0]) + message.labels["block_number"] = str(log.block_number) + message.labels["transaction_hash"] = log.transaction_hash + message.labels["log_index"] = str(log.log_index) - logger.info(f"{self._create_label(message)} - Started") + logger.debug(f"{self._create_label(message)} - Started") return message def post_execute(self, message: TaskiqMessage, result: TaskiqResult): - percentage_time = 100 * (result.execution_time / self.block_time) - logger.info( - f"{self._create_label(message)} " - f"- {result.execution_time:.3f}s ({percentage_time:.1f}%)" + if self.block_time: + percentage_time = 100 * (result.execution_time / self.block_time) + percent_display = f" ({percentage_time:.1f}%)" + + else: + percent_display = "" + + (logger.error if result.error else logger.success)( + f"{self._create_label(message)} " f"- {result.execution_time:.3f}s{percent_display}" ) async def post_save(self, message: TaskiqMessage, result: TaskiqResult): if not self.persistence: return - handler_id, block_number, log_index = resolve_task(message) - handler_result = HandlerResult.from_taskiq( - self.ident, handler_id, block_number, log_index, result + self.ident, + message.task_name, + message.labels.get("block_number"), + message.labels.get("log_index"), + result, ) try: @@ -113,10 +112,4 @@ async def post_save(self, message: TaskiqMessage, result: TaskiqResult): except Exception as err: logger.error(f"Error storing result: {err}") - async def on_error( - self, - message: TaskiqMessage, - result: TaskiqResult, - exception: BaseException, - ): - logger.error(f"{message.task_name} - {type(exception).__name__}: {exception}") + # NOTE: Unless stdout is ignored, error traceback appears in stdout, no need for `on_error` diff --git a/silverback/runner.py b/silverback/runner.py index 042841ad..8f014b4e 100644 --- a/silverback/runner.py +++ b/silverback/runner.py @@ -14,7 +14,7 @@ from .persistence import BasePersistentStore from .settings import Settings from .subscriptions import SubscriptionType, Web3SubscriptionsManager -from .types import SilverbackID, SilverbackStartupState +from .types import SilverbackID, SilverbackStartupState, TaskType from .utils import async_wrap_iter, hexbytes_dict settings = Settings() @@ -103,8 +103,8 @@ async def run(self): await self.app.broker.startup() # Execute Silverback startup task before we init the rest - if startup_handler := self.app.get_startup_handler(): - task = await startup_handler.kiq( + for startup_task in self.app.tasks[TaskType.STARTUP]: + task = await startup_task.handler.kiq( SilverbackStartupState( last_block_seen=self.last_block_seen, last_block_processed=self.last_block_processed, @@ -113,15 +113,12 @@ async def run(self): result = await task.wait_result() self._handle_result(result) - if block_handler := self.app.get_block_handler(): - tasks = [self._block_task(block_handler)] - else: - tasks = [] + tasks = [] + for task in self.app.tasks[TaskType.NEW_BLOCKS]: + tasks.append(self._block_task(task.handler)) - for contract_address in self.app.contract_events: - for event_name, contract_event in self.app.contract_events[contract_address].items(): - if event_handler := self.app.get_event_handler(contract_address, event_name): - tasks.append(self._event_task(contract_event, event_handler)) + for task in self.app.tasks[TaskType.EVENT_LOG]: + tasks.append(self._event_task(task.container, task.handler)) if len(tasks) == 0: raise Halt("No tasks to execute") @@ -132,10 +129,9 @@ async def run(self): logger.error(f"Fatal error detected, shutting down: '{e}'") # Execute Silverback shutdown task before shutting down the broker - if shutdown_handler := self.app.get_shutdown_handler(): - task = await shutdown_handler.kiq() - result = await task.wait_result() - self._handle_result(result) + for shutdown_task in self.app.tasks[TaskType.SHUTDOWN]: + task = await shutdown_task.handler.kiq() + result = self._handle_result(await task.wait_result()) await self.app.broker.shutdown() diff --git a/silverback/types.py b/silverback/types.py index c502ed25..542f77a5 100644 --- a/silverback/types.py +++ b/silverback/types.py @@ -1,9 +1,20 @@ +from enum import Enum # NOTE: `enum.StrEnum` only in Python 3.11+ from typing import Optional, Protocol from pydantic import BaseModel from typing_extensions import Self # Introduced 3.11 +class TaskType(str, Enum): + STARTUP = "silverback_startup" # TODO: Shorten in 0.4.0 + NEW_BLOCKS = "block" + EVENT_LOG = "event" + SHUTDOWN = "silverback_shutdown" # TODO: Shorten in 0.4.0 + + def __str__(self) -> str: + return self.value + + class ISilverbackSettings(Protocol): """Loose approximation of silverback.settings.Settings. If you can, use the class as a type reference.""" @@ -27,16 +38,3 @@ def from_settings(cls, settings_: ISilverbackSettings) -> Self: class SilverbackStartupState(BaseModel): last_block_seen: int last_block_processed: int - - -def handler_id_block(block_number: Optional[int]) -> str: - """Return a unique handler ID string for a block""" - if block_number is None: - return "block/pending" - return f"block/{block_number}" - - -def handler_id_event(contract_address: Optional[str], event_signature: str) -> str: - """Return a unique handler ID string for an event""" - # TODO: Under what circumstance can address be None? - return f"{contract_address or 'unknown'}/event/{event_signature}"