Skip to content

Commit

Permalink
feat(runner): add BacktestRunner w/ silverback test command
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed Nov 24, 2024
1 parent e949e04 commit 3cc8392
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
],
entry_points={
"console_scripts": ["silverback=silverback._cli:cli"],
"pytest11": ["silverback_test=silverback.pytest"],
},
python_requires=">=3.10,<4",
extras_require=extras_require,
Expand Down
1 change: 1 addition & 0 deletions silverback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __getattr__(name: str):

__all__ = [
"StateSnapshot",
"BacktestRunner",
"CircuitBreaker",
"SilverbackBot",
"SilverbackException",
Expand Down
27 changes: 27 additions & 0 deletions silverback/_cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import os
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import click
import pytest
import yaml # type: ignore[import-untyped]
from ape.cli import (
AccountAliasPromptChoice,
Expand All @@ -13,6 +15,7 @@
account_option,
ape_cli_context,
network_option,
verbosity_option,
)
from ape.exceptions import Abort, ApeException
from ape.logging import LogLevel
Expand Down Expand Up @@ -171,6 +174,30 @@ def worker(cli_ctx, account, workers, max_exceptions, shutdown_timeout, bot):
asyncio.run(run_worker(bot.broker, worker_count=workers, shutdown_timeout=shutdown_timeout))


@cli.command(
section="Local Commands",
add_help_option=False, # NOTE: This allows pass-through to pytest's help
short_help="Run bot backtests (`tests/backtest_*.yaml`)",
context_settings=dict(ignore_unknown_options=True),
)
@ape_cli_context()
@verbosity_option()
@network_option(
default=os.environ.get("SILVERBACK_NETWORK_CHOICE", "auto"),
callback=_network_callback,
)
@click.option("--bot", "bots", multiple=True)
@click.argument("pytest_args", nargs=-1, type=click.UNPROCESSED)
def test(cli_ctx, network, bots, pytest_args):
os.environ["SILVERBACK_FORK_MODE"] = "1"

return_code = pytest.main([*pytest_args], ["silverback_test"])

if return_code:
# only exit with non-zero status to make testing easier
sys.exit(return_code)


@cli.command(section="Cloud Commands (https://silverback.apeworx.io)")
@auth_required
def login(auth: "FiefAuth"):
Expand Down
110 changes: 110 additions & 0 deletions silverback/pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import asyncio
import os
from pathlib import Path

import pytest
import yaml # type: ignore[import]
from ape.utils import cached_property

from silverback._importer import import_from_string
from silverback.exceptions import SilverbackException
from silverback.runner import BacktestRunner


class AssertionViolation(SilverbackException):
pass


def pytest_collect_file(parent, file_path):
if file_path.suffix == ".yaml" and file_path.name.startswith("backtest"):
return BacktestFile.from_parent(parent, path=file_path)


class BacktestFile(pytest.File):
def collect(self):
raw = yaml.safe_load(self.path.open())
raw_bot_path = raw.get("bot")
if isinstance(raw_bot_path, list):
for bot_path in raw_bot_path:
if ":" in bot_path:
bot_path, bot_name = bot_path.split(":")
bot_path = Path(bot_path)
else:
bot_path = Path(bot_path)
bot_name = "bot"

yield BacktestItem.from_parent(
self,
name=self.name,
file_path=self.path,
bot_path=bot_path,
bot_name=bot_name,
network_triple=raw.get("network", ""),
start_block=raw.get("start_block", 0),
stop_block=raw.get("stop_block", -1),
assertion_checks=raw.get("assertions", {}),
)

else:
if ":" in raw_bot_path:
bot_path, bot_name = raw_bot_path.split(":")
bot_path = Path(bot_path)
else:
bot_path = Path(raw_bot_path)
bot_name = "bot"

yield BacktestItem.from_parent(
self,
name=self.name,
file_path=self.path,
bot_path=bot_path,
bot_name=bot_name,
network_triple=raw.get("network", ""),
start_block=raw.get("start_block", 0),
stop_block=raw.get("stop_block", -1),
assertion_checks=raw.get("assertions", {}),
)


class BacktestItem(pytest.Item):
def __init__(
self,
*,
file_path,
bot_path,
bot_name,
network_triple,
start_block,
stop_block,
assertion_checks,
**kwargs,
):
super().__init__(**kwargs)
self.file_path = file_path
self.bot_path = bot_path
self.bot_name = bot_name
self.network_triple = network_triple
self.start_block = start_block
self.stop_block = stop_block
self.assertion_checks = assertion_checks

self.assertion_failures = 0
self.overruns = 0

@cached_property
def runner(self):
os.environ["SILVERBACK_NETWORK_CHOICE"] = self.network_triple
os.environ["PYTHONPATH"] = str(self.bot_path.parent)
app = import_from_string(f"{self.bot_path.stem}:{self.bot_name}")
return BacktestRunner(app, start_block=self.start_block, stop_block=self.stop_block)

def check_assertions(self, result: dict):
pass

def runtest(self):
asyncio.run(self.runner.run())
self.raise_run_status()

def raise_run_status(self):
if self.overruns > 0 or self.assertion_failures > 0:
raise AssertionViolation()
64 changes: 64 additions & 0 deletions silverback/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ape.logging import logger
from ape.utils import ManagerAccessMixin
from ape_ethereum.ecosystem import keccak
from click import progressbar
from ethpm_types import EventABI
from packaging.specifiers import SpecifierSet
from packaging.version import Version
Expand Down Expand Up @@ -420,3 +421,66 @@ async def _event_task(self, task_data: TaskData):
await self._checkpoint(last_block_seen=event.block_number)
await self._handle_task(await event_log_task_kicker.kiq(event))
await self._checkpoint(last_block_processed=event.block_number)


class BacktestRunner(BaseRunner):
def __init__(
self,
app: SilverbackBot,
start_block: int,
stop_block: int,
*args,
**kwargs,
):
super().__init__(app, *args, **kwargs)

# NOTE: Takes time to do the data collection
with progressbar(
chain.blocks.range(start_block, stop_block + 1),
length=(stop_block - start_block),
) as blocks:
self.blocks = list(blocks)

logger.info(
f"Using {self.__class__.__name__}:"
f" num_blocks={stop_block - start_block}"
f" max_exceptions={self.max_exceptions}"
)

async def _block_task(self, task_data: TaskData):
new_block_task_kicker = self._create_task_kicker(task_data)

async for block in async_wrap_iter(iter(self.blocks)):
await self._checkpoint(last_block_seen=block.number)
await self._handle_task(await new_block_task_kicker.kiq(block))
await self._checkpoint(last_block_processed=block.number)

async def _event_task(self, task_data: TaskData):
if not (event_signature := task_data.labels.get("event_signature")):
raise StartupFailure("No Event Signature provided.")

event_abi = EventABI.from_signature(event_signature)

if not (contract_address := task_data.labels.get("contract_address")):
raise StartupFailure("Contract instance required.")

if (
not (
events := chain.contracts.instance_at(contract_address)._events_.get(event_abi.name)
)
or len(events) == 0
):
raise StartupFailure(
"Contract '{contract_address}' does not have event '{event_abi.name}'."
)

event_log_task_kicker = self._create_task_kicker(task_data)

async for block in async_wrap_iter(iter(self.blocks)):
txn_hashes = iter(tx.txn_hash.hex() for tx in block.transactions)
receipts = map(chain.get_receipt, txn_hashes)
async for logs in async_wrap_iter(map(events[0].from_receipt, receipts)):
async for log in async_wrap_iter(iter(logs)):
await self._checkpoint(last_block_seen=log.block_number)
await self._handle_task(await event_log_task_kicker.kiq(log))
await self._checkpoint(last_block_processed=log.block_number)
5 changes: 5 additions & 0 deletions tests/backtest_merge.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
bot: "example"
network: "ethereum:mainnet-fork"
start_block: 15_338_009
stop_block: 15_338_018
something_else: blah

0 comments on commit 3cc8392

Please sign in to comment.