Skip to content

Commit

Permalink
refactor: allow multiple handlers (#66)
Browse files Browse the repository at this point in the history
* refactor: add TaskType enum to types

* refactor!: use task collection by task type; use dynamic registration

* refactor: use new task collections api for fetch task handlers

* refactor!: use task type to differentiate tasks in middleware, not name

* fix: typing bugs, not using Union for 3.8 support

* fix: use task type to capture event logs

* refactor: use defaultdict instead of custom collection type

* refactor: use standardized labels, use task_name for task_id

* refactor: remove `.task_name` from message labels

* refactor: convert to TaskType for better processing

* refactor: use StrEnum if available

* docs: add note to deprecate in breaking change

* refactor: make object type clearer when working with labels in middleware

* refactor: use official backport

* docs: update typing and add docs for dynamic broker task decorator fn

* style: ignore mypy typing issues on <3.11

* refactor: avoid div/0 fault, fix duplicate log entry for results w/errs

* refactor: rollback weird typing backport issue
  • Loading branch information
fubuloubu authored Apr 11, 2024
1 parent 4a1f70b commit 87672ac
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 142 deletions.
127 changes: 64 additions & 63 deletions silverback/application.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
import atexit
from collections import defaultdict
from dataclasses import dataclass
from datetime import timedelta
from typing import Callable, Dict, Optional, Union

from ape.api.networks import LOCAL_NETWORK_NAME
from ape.contracts import ContractEvent, ContractInstance
from ape.logging import logger
from ape.managers.chain import BlockContainer
from ape.types import AddressType
from ape.utils import ManagerAccessMixin
from taskiq import AsyncTaskiqDecoratedTask, TaskiqEvents

from .exceptions import DuplicateHandlerError, InvalidContainerTypeError
from .exceptions import ContainerTypeMismatchError, InvalidContainerTypeError
from .settings import Settings
from .types import TaskType


@dataclass
class TaskData:
container: Union[BlockContainer, ContractEvent, None]
handler: AsyncTaskiqDecoratedTask


class SilverbackApp(ManagerAccessMixin):
Expand Down Expand Up @@ -52,7 +60,8 @@ def __init__(self, settings: Optional[Settings] = None):
logger.info(f"Loading Silverback App with settings:\n {settings_str}")

self.broker = settings.get_broker()
self.contract_events: Dict[AddressType, Dict[str, ContractEvent]] = {}
# NOTE: If no tasks registered yet, defaults to empty list instead of raising KeyError
self.tasks: defaultdict[TaskType, list[TaskData]] = defaultdict(list)
self.poll_settings: Dict[str, Dict] = {}

atexit.register(self.network.__exit__, None, None, None)
Expand All @@ -72,6 +81,54 @@ def __init__(self, settings: Optional[Settings] = None):
f"{signer_str}{start_block_str}{new_block_timeout_str}"
)

def broker_task_decorator(
self,
task_type: TaskType,
container: Union[BlockContainer, ContractEvent, None] = None,
) -> Callable[[Callable], AsyncTaskiqDecoratedTask]:
"""
Dynamically create a new broker task that handles tasks of ``task_type``.
Args:
task_type: :class:`~silverback.types.TaskType`: The type of task to create.
container: (Union[BlockContainer, ContractEvent]): The event source to watch.
Returns:
Callable[[Callable], :class:`~taskiq.AsyncTaskiqDecoratedTask`]:
A function wrapper that will register the task handler.
Raises:
:class:`~silverback.exceptions.ContainerTypeMismatchError`:
If there is a mismatch between `task_type` and the `container`
type it should handle.
"""
if (
(task_type is TaskType.NEW_BLOCKS and not isinstance(container, BlockContainer))
or (task_type is TaskType.EVENT_LOG and not isinstance(container, ContractEvent))
or (
task_type
not in (
TaskType.NEW_BLOCKS,
TaskType.EVENT_LOG,
)
and container is not None
)
):
raise ContainerTypeMismatchError(task_type, container)

# Register user function as task handler with our broker
def add_taskiq_task(handler: Callable) -> AsyncTaskiqDecoratedTask:
broker_task = self.broker.register_task(
handler,
task_name=handler.__name__,
task_type=str(task_type),
)

self.tasks[task_type].append(TaskData(container=container, handler=broker_task))
return broker_task

return add_taskiq_task

def on_startup(self) -> Callable:
"""
Code to execute on one worker upon startup / restart after an error.
Expand All @@ -82,7 +139,7 @@ def on_startup(self) -> Callable:
def do_something_on_startup(startup_state):
... # Reprocess missed events or blocks
"""
return self.broker.task(task_name="silverback_startup")
return self.broker_task_decorator(TaskType.STARTUP)

def on_shutdown(self) -> Callable:
"""
Expand All @@ -94,7 +151,7 @@ def on_shutdown(self) -> Callable:
def do_something_on_shutdown():
... # Record final state of app
"""
return self.broker.task(task_name="silverback_shutdown")
return self.broker_task_decorator(TaskType.SHUTDOWN)

def on_worker_startup(self) -> Callable:
"""
Expand All @@ -120,48 +177,6 @@ def do_something_on_shutdown(state):
"""
return self.broker.on_event(TaskiqEvents.WORKER_SHUTDOWN)

def get_startup_handler(self) -> Optional[AsyncTaskiqDecoratedTask]:
"""
Get access to the handler for `silverback_startup` events.
Returns:
Optional[AsyncTaskiqDecoratedTask]: Returns decorated task, if one has been created.
"""
return self.broker.find_task("silverback_startup")

def get_shutdown_handler(self) -> Optional[AsyncTaskiqDecoratedTask]:
"""
Get access to the handler for `silverback_shutdown` events.
Returns:
Optional[AsyncTaskiqDecoratedTask]: Returns decorated task, if one has been created.
"""
return self.broker.find_task("silverback_shutdown")

def get_block_handler(self) -> Optional[AsyncTaskiqDecoratedTask]:
"""
Get access to the handler for `block` events.
Returns:
Optional[AsyncTaskiqDecoratedTask]: Returns decorated task, if one has been created.
"""
return self.broker.find_task("block")

def get_event_handler(
self, event_target: AddressType, event_name: str
) -> Optional[AsyncTaskiqDecoratedTask]:
"""
Get access to the handler for `<event_target>:<event_name>` events.
Args:
event_target (AddressType): The contract address of the target.
event_name: (str): The name of the event emitted by ``event_target``.
Returns:
Optional[AsyncTaskiqDecoratedTask]: Returns decorated task, if one has been created.
"""
return self.broker.find_task(f"{event_target}/event/{event_name}")

def on_(
self,
container: Union[BlockContainer, ContractEvent],
Expand All @@ -183,9 +198,6 @@ def on_(
If the type of `container` is not configurable for the app.
"""
if isinstance(container, BlockContainer):
if self.get_block_handler():
raise DuplicateHandlerError("block")

if new_block_timeout is not None:
if "_blocks_" in self.poll_settings:
self.poll_settings["_blocks_"]["new_block_timeout"] = new_block_timeout
Expand All @@ -198,21 +210,12 @@ def on_(
else:
self.poll_settings["_blocks_"] = {"start_block": start_block}

return self.broker.task(task_name="block")
return self.broker_task_decorator(TaskType.NEW_BLOCKS, container=container)

elif isinstance(container, ContractEvent) and isinstance(
container.contract, ContractInstance
):
if self.get_event_handler(container.contract.address, container.abi.name):
raise DuplicateHandlerError(
f"event {container.contract.address}:{container.abi.name}"
)

key = container.contract.address
if container.contract.address in self.contract_events:
self.contract_events[key][container.abi.name] = container
else:
self.contract_events[key] = {container.abi.name: container}

if new_block_timeout is not None:
if key in self.poll_settings:
Expand All @@ -226,9 +229,7 @@ def on_(
else:
self.poll_settings[key] = {"start_block": start_block}

return self.broker.task(
task_name=f"{container.contract.address}/event/{container.abi.name}"
)
return self.broker_task_decorator(TaskType.EVENT_LOG, container=container)

# TODO: Support account transaction polling
# TODO: Support mempool polling
Expand Down
12 changes: 7 additions & 5 deletions silverback/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
from ape.exceptions import ApeException
from ape.logging import logger

from .types import TaskType


class ImportFromStringError(Exception):
pass


class DuplicateHandlerError(Exception):
def __init__(self, handler_type: str):
super().__init__(f"Only one handler allowed for: {handler_type}")


class InvalidContainerTypeError(Exception):
def __init__(self, container: Any):
super().__init__(f"Invalid container type: {container.__class__}")


class ContainerTypeMismatchError(Exception):
def __init__(self, task_type: TaskType, container: Any):
super().__init__(f"Invalid container type for '{task_type}': {container.__class__}")


class NoWebsocketAvailableError(Exception):
def __init__(self):
super().__init__(
Expand Down
85 changes: 39 additions & 46 deletions silverback/middlewares.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Tuple
from typing import Any

from ape.logging import logger
from ape.types import ContractLog
Expand All @@ -7,27 +7,10 @@
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult

from silverback.persistence import HandlerResult
from silverback.types import SilverbackID, handler_id_block, handler_id_event
from silverback.types import SilverbackID, TaskType
from silverback.utils import hexbytes_dict


def resolve_task(message: TaskiqMessage) -> Tuple[str, Optional[int], Optional[int]]:
block_number = None
log_index = None
task_id = message.task_name

if task_id == "block":
block_number = message.args[0].number
task_id = handler_id_block(block_number)
elif "event" in task_id:
block_number = message.args[0].block_number
log_index = message.args[0].log_index
# TODO: Should standardize on event signature here instead of name in case of overloading
task_id = handler_id_event(message.args[0].contract_address, message.args[0].event_name)

return task_id, block_number, log_index


class SilverbackMiddleware(TaskiqMiddleware, ManagerAccessMixin):
def __init__(self, *args, **kwargs):
def compute_block_time() -> int:
Expand Down Expand Up @@ -66,57 +49,67 @@ def fix_dict(data: dict, recurse_count: int = 0) -> dict:
return message

def _create_label(self, message: TaskiqMessage) -> str:
if message.task_name == "block":
args = f"[block={message.args[0].hash.hex()}]"

elif "event" in message.task_name:
args = f"[txn={message.args[0].transaction_hash},log_index={message.args[0].log_index}]"
if labels_str := ",".join(f"{k}={v}" for k, v in message.labels.items()):
return f"{message.task_name}[{labels_str}]"

else:
args = ""

return f"{message.task_name}{args}"
return message.task_name

def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage:
if message.task_name == "block":
if not (task_type := message.labels.pop("task_type")):
return message # Not a silverback task

try:
task_type = TaskType(task_type)
except ValueError:
return message # Not a silverback task

# Add extra labels for our task to see what their source was
if task_type is TaskType.NEW_BLOCKS:
# NOTE: Necessary because we don't know the exact block class
message.args[0] = self.provider.network.ecosystem.decode_block(
block = message.args[0] = self.provider.network.ecosystem.decode_block(
hexbytes_dict(message.args[0])
)
message.labels["block_number"] = str(block.number)
message.labels["block_hash"] = block.hash.hex()

elif "event" in message.task_name:
elif task_type is TaskType.EVENT_LOG:
# NOTE: Just in case the user doesn't specify type as `ContractLog`
message.args[0] = ContractLog.model_validate(message.args[0])
log = message.args[0] = ContractLog.model_validate(message.args[0])
message.labels["block_number"] = str(log.block_number)
message.labels["transaction_hash"] = log.transaction_hash
message.labels["log_index"] = str(log.log_index)

logger.info(f"{self._create_label(message)} - Started")
logger.debug(f"{self._create_label(message)} - Started")
return message

def post_execute(self, message: TaskiqMessage, result: TaskiqResult):
percentage_time = 100 * (result.execution_time / self.block_time)
logger.info(
f"{self._create_label(message)} "
f"- {result.execution_time:.3f}s ({percentage_time:.1f}%)"
if self.block_time:
percentage_time = 100 * (result.execution_time / self.block_time)
percent_display = f" ({percentage_time:.1f}%)"

else:
percent_display = ""

(logger.error if result.error else logger.success)(
f"{self._create_label(message)} " f"- {result.execution_time:.3f}s{percent_display}"
)

async def post_save(self, message: TaskiqMessage, result: TaskiqResult):
if not self.persistence:
return

handler_id, block_number, log_index = resolve_task(message)

handler_result = HandlerResult.from_taskiq(
self.ident, handler_id, block_number, log_index, result
self.ident,
message.task_name,
message.labels.get("block_number"),
message.labels.get("log_index"),
result,
)

try:
await self.persistence.add_result(handler_result)
except Exception as err:
logger.error(f"Error storing result: {err}")

async def on_error(
self,
message: TaskiqMessage,
result: TaskiqResult,
exception: BaseException,
):
logger.error(f"{message.task_name} - {type(exception).__name__}: {exception}")
# NOTE: Unless stdout is ignored, error traceback appears in stdout, no need for `on_error`
Loading

0 comments on commit 87672ac

Please sign in to comment.