Skip to content

Commit

Permalink
feat(runner): add BacktestRunner w/ silverback test command
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed Oct 28, 2024
1 parent 6f4b117 commit 2d4ffb5
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
],
entry_points={
"console_scripts": ["silverback=silverback._cli:cli"],
"pytest11": ["silverback_test=silverback.pytest"],
},
python_requires=">=3.10,<4",
extras_require=extras_require,
Expand Down
1 change: 1 addition & 0 deletions silverback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

__all__ = [
"StateSnapshot",
"BacktestRunner",
"CircuitBreaker",
"SilverbackBot",
"SilverbackException",
Expand Down
26 changes: 26 additions & 0 deletions silverback/_cli.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import os
import sys
from datetime import datetime, timedelta, timezone
from pathlib import Path

import click
import pytest
import yaml # type: ignore[import-untyped]
from ape.api import AccountAPI, NetworkAPI
from ape.cli import (
Expand All @@ -12,6 +14,7 @@
account_option,
ape_cli_context,
network_option,
verbosity_option,
)
from ape.contracts import ContractInstance
from ape.exceptions import Abort, ApeException
Expand Down Expand Up @@ -162,6 +165,29 @@ def worker(cli_ctx, account, workers, max_exceptions, shutdown_timeout, bot):
asyncio.run(run_worker(bot.broker, worker_count=workers, shutdown_timeout=shutdown_timeout))


@cli.command(
section="Local Commands",
add_help_option=False, # NOTE: This allows pass-through to pytest's help
short_help="Launches pytest and runs the tests for an app",
context_settings=dict(ignore_unknown_options=True),
)
@ape_cli_context()
@verbosity_option()
@network_option(default=None, callback=_network_callback)
@click.argument("pytest_args", nargs=-1, type=click.UNPROCESSED)
def test(cli_ctx, network, pytest_args):
if not network:
os.environ["SILVERBACK_NETWORK_CHOICE"] = ":mainnet-fork"

os.environ["SILVERBACK_FORK_MODE"] = "1"

return_code = pytest.main([*pytest_args], ["silverback_test"])

if return_code:
# only exit with non-zero status to make testing easier
sys.exit(return_code)


@cli.command(section="Cloud Commands (https://silverback.apeworx.io)")
@auth_required
def login(auth: FiefAuth):
Expand Down
79 changes: 79 additions & 0 deletions silverback/pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import asyncio
import os
from pathlib import Path

import pytest
import yaml # type: ignore[import]
from ape.utils import cached_property
from uvicorn.importer import import_from_string

from silverback.exceptions import SilverbackException
from silverback.runner import BacktestRunner


class AssertionViolation(SilverbackException):
pass


def pytest_collect_file(parent, file_path):
if file_path.suffix == ".yaml" and file_path.name.startswith("backtest"):
return BacktestFile.from_parent(parent, path=file_path)


class BacktestFile(pytest.File):
def collect(self):
raw = yaml.safe_load(self.path.open())
yield BacktestItem.from_parent(
self,
name=self.name,
file_path=self.path,
app_path=raw.get("app", os.environ.get("SILVERBACK_APP")),
network_triple=raw.get("network", ""),
start_block=raw.get("start_block", 0),
stop_block=raw.get("stop_block", -1),
assertion_checks=raw.get("assertions", {}),
)


class BacktestItem(pytest.Item):
def __init__(
self,
*,
file_path,
app_path,
network_triple,
start_block,
stop_block,
assertion_checks,
**kwargs,
):
super().__init__(**kwargs)
self.file_path = file_path
self.app_path = app_path or "app.py:app"
self.network_triple = network_triple
self.start_block = start_block
self.stop_block = stop_block
self.assertion_checks = assertion_checks

self.assertion_failures = 0
self.overruns = 0

@cached_property
def runner(self):
app_path, app_name = self.app_path.split(":")
app_path = Path(app_path)
os.environ["SILVERBACK_NETWORK_CHOICE"] = self.network_triple
os.environ["PYTHONPATH"] = str(app_path.parent)
app = import_from_string(f"{app_path.stem}:{app_name}")
return BacktestRunner(app, start_block=self.start_block, stop_block=self.stop_block)

def check_assertions(self, result: dict):
pass

def runtest(self):
asyncio.run(self.runner.run())
self.raise_run_status()

def raise_run_status(self):
if self.overruns > 0 or self.assertion_failures > 0:
raise AssertionViolation()
61 changes: 61 additions & 0 deletions silverback/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ape.logging import logger
from ape.utils import ManagerAccessMixin
from ape_ethereum.ecosystem import keccak
from click import progressbar
from ethpm_types import EventABI
from packaging.specifiers import SpecifierSet
from packaging.version import Version
Expand Down Expand Up @@ -420,3 +421,63 @@ async def _event_task(self, task_data: TaskData):
await self._checkpoint(last_block_seen=event.block_number)
await self._handle_task(await event_log_task_kicker.kiq(event))
await self._checkpoint(last_block_processed=event.block_number)


class BacktestRunner(BaseRunner):
def __init__(
self,
app: SilverbackBot,
start_block: int,
stop_block: int,
*args,
**kwargs,
):
super().__init__(app, *args, **kwargs)

# NOTE: Takes time to do the data collection
with progressbar(
chain.blocks.range(start_block, stop_block + 1),
length=(stop_block - start_block),
) as blocks:
self.blocks = list(blocks)

logger.info(
f"Using {self.__class__.__name__}:"
f" num_blocks={stop_block - start_block}"
f" max_exceptions={self.max_exceptions}"
)

async def _block_task(self, task_data: TaskData):
new_block_task_kicker = self._create_task_kicker(task_data)

async for block in async_wrap_iter(iter(self.blocks)):
await self._checkpoint(last_block_seen=block.number)
await self._handle_task(await new_block_task_kicker.kiq(block))
await self._checkpoint(last_block_processed=block.number)

async def _event_task(self, task_data: TaskData):
if not (event_signature := task_data.labels.get("event_signature")):
raise StartupFailure("No Event Signature provided.")

event_abi = EventABI.from_signature(event_signature)

if not (contract_address := task_data.labels.get("contract_address")):
raise StartupFailure("Contract instance required.")

if (
not (
events := chain.contracts.instance_at(contract_address)._events_.get(event_abi.name)
)
or len(events) == 0
):
raise StartupFailure(
"Contract '{contract_address}' does not have event '{event_abi.name}'."
)

event_log_task_kicker = self._create_task_kicker(task_data)

async for block in async_wrap_iter(iter(self.blocks)):
async for log in async_wrap_iter(map(events[0].from_receipt, block.transactions)):
await self._checkpoint(last_block_seen=log.block_number)
await self._handle_task(await event_log_task_kicker.kiq(log))
await self._checkpoint(last_block_processed=log.block_number)
5 changes: 5 additions & 0 deletions tests/backtest_merge.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
bot: "example"
network: "ethereum:mainnet-fork"
start_block: 15_338_009
stop_block: 15_338_018
something_else: blah

0 comments on commit 2d4ffb5

Please sign in to comment.