From 0d36142fd6b90268163ab8e62e9a1014f73790ca Mon Sep 17 00:00:00 2001
From: Doggie B <3859395+fubuloubu@users.noreply.github.com>
Date: Mon, 28 Oct 2024 18:45:43 -0400
Subject: [PATCH] feat(runner): add BacktestRunner w/ `silverback test` command

---
 setup.py                  |   1 +
 silverback/__init__.py    |   1 +
 silverback/_cli.py        |  27 +++++++++
 silverback/pytest.py      | 117 ++++++++++++++++++++++++++++++++++++++
 silverback/runner.py      |  64 +++++++++++++++++++++
 tests/backtest_merge.yaml |   5 ++
 6 files changed, 215 insertions(+)
 create mode 100644 silverback/pytest.py
 create mode 100644 tests/backtest_merge.yaml

diff --git a/setup.py b/setup.py
index 7c6393e7..55b2a5c6 100644
--- a/setup.py
+++ b/setup.py
@@ -75,6 +75,7 @@
     ],
     entry_points={
         "console_scripts": ["silverback=silverback._cli:cli"],
+        "pytest11": ["silverback_test=silverback.pytest"],
     },
     python_requires=">=3.10,<4",
     extras_require=extras_require,
diff --git a/silverback/__init__.py b/silverback/__init__.py
index 1f55c662..75a3e070 100644
--- a/silverback/__init__.py
+++ b/silverback/__init__.py
@@ -22,6 +22,7 @@ def __getattr__(name: str):
 
 __all__ = [
     "StateSnapshot",
+    "BacktestRunner",
     "CircuitBreaker",
     "SilverbackBot",
     "SilverbackException",
diff --git a/silverback/_cli.py b/silverback/_cli.py
index 043ff5e6..3dd83634 100644
--- a/silverback/_cli.py
+++ b/silverback/_cli.py
@@ -1,10 +1,12 @@
 import asyncio
 import os
+import sys
 from datetime import datetime, timedelta, timezone
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional
 
 import click
+import pytest
 import yaml  # type: ignore[import-untyped]
 from ape.cli import (
     AccountAliasPromptChoice,
@@ -13,6 +15,7 @@
     account_option,
     ape_cli_context,
     network_option,
+    verbosity_option,
 )
 from ape.exceptions import Abort, ApeException
 from ape.logging import LogLevel
@@ -171,6 +174,30 @@ def worker(cli_ctx, account, workers, max_exceptions, shutdown_timeout, bot):
     asyncio.run(run_worker(bot.broker, worker_count=workers, shutdown_timeout=shutdown_timeout))
 
 
+@cli.command(
+    section="Local Commands",
+    add_help_option=False,  # NOTE: This allows pass-through to pytest's help
+    short_help="Run bot backtests (`tests/backtest_*.yaml`)",
+    context_settings=dict(ignore_unknown_options=True),
+)
+@ape_cli_context()
+@verbosity_option()
+@network_option(
+    default=os.environ.get("SILVERBACK_NETWORK_CHOICE", "auto"),
+    callback=_network_callback,
+)
+@click.option("--bot", "bots", multiple=True)
+@click.argument("pytest_args", nargs=-1, type=click.UNPROCESSED)
+def test(cli_ctx, network, bots, pytest_args):
+    os.environ["SILVERBACK_FORK_MODE"] = "1"
+
+    return_code = pytest.main([*pytest_args], ["silverback.pytest"])
+
+    if return_code:
+        # only exit with non-zero status to make testing easier
+        sys.exit(return_code)
+
+
 @cli.command(section="Cloud Commands (https://silverback.apeworx.io)")
 @auth_required
 def login(auth: "FiefAuth"):
diff --git a/silverback/pytest.py b/silverback/pytest.py
new file mode 100644
index 00000000..46e2e5c8
--- /dev/null
+++ b/silverback/pytest.py
@@ -0,0 +1,117 @@
+import asyncio
+import os
+from pathlib import Path
+
+import pytest
+import yaml  # type: ignore[import]
+from ape.utils import cached_property
+
+from silverback._importer import import_from_string
+from silverback.exceptions import SilverbackException
+from silverback.runner import BacktestRunner
+
+
+class AssertionViolation(SilverbackException):
+    pass
+
+
+def pytest_collect_file(parent, file_path):
+    if file_path.suffix == ".yaml" and file_path.name.startswith("backtest"):
+        return BacktestFile.from_parent(parent, path=file_path)
+
+
+class BacktestFile(pytest.File):
+    def collect(self):
+        raw = yaml.safe_load(self.path.open())
+        if not (network_triple := raw.get("network")):
+            raise ValueError(f"{self.path} is missing key 'network'.")
+
+        start_block = raw.get("start_block", 0)
+        stop_block = raw.get("stop_block", -1)
+        assertion_checks = raw.get("assertions", {})
+
+        raw_bot_paths = raw.get("bots")
+        if isinstance(raw_bot_paths, list):
+            for bot_path in raw_bot_paths:
+                if ":" in bot_path:
+                    bot_path, bot_name = bot_path.split(":")
+                    bot_path = Path(bot_path)
+                else:
+                    bot_path = Path(bot_path)
+                    bot_name = "bot"
+
+                yield BacktestItem.from_parent(
+                    self,
+                    name=f"{self.name}[{bot_name}]",
+                    file_path=self.path,
+                    bot_path=bot_path,
+                    bot_name=bot_name,
+                    network_triple=network_triple,
+                    start_block=start_block,
+                    stop_block=stop_block,
+                    assertion_checks=assertion_checks,
+                )
+
+        else:
+            if ":" in raw_bot_paths:
+                bot_path, bot_name = raw_bot_paths.split(":")
+                bot_path = Path(bot_path)
+            else:
+                bot_path = Path(raw_bot_paths)
+                bot_name = "bot"
+
+            yield BacktestItem.from_parent(
+                self,
+                name=self.name,
+                file_path=self.path,
+                bot_path=bot_path,
+                bot_name=bot_name,
+                network_triple=network_triple,
+                start_block=start_block,
+                stop_block=stop_block,
+                assertion_checks=assertion_checks,
+            )
+
+
+class BacktestItem(pytest.Item):
+    def __init__(
+        self,
+        *,
+        file_path,
+        bot_path,
+        bot_name,
+        network_triple,
+        start_block,
+        stop_block,
+        assertion_checks,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.file_path = file_path
+        self.bot_path = bot_path
+        self.bot_name = bot_name
+        self.network_triple = network_triple
+        self.start_block = start_block
+        self.stop_block = stop_block
+        self.assertion_checks = assertion_checks
+
+        self.assertion_failures = 0
+        self.overruns = 0
+
+    @cached_property
+    def runner(self):
+        os.environ["SILVERBACK_NETWORK_CHOICE"] = self.network_triple
+        os.environ["PYTHONPATH"] = str(self.bot_path.parent)
+        app = import_from_string(f"{self.bot_path.stem}:{self.bot_name}")
+        return BacktestRunner(app, start_block=self.start_block, stop_block=self.stop_block)
+
+    def check_assertions(self, result: dict):
+        pass
+
+    def runtest(self):
+        asyncio.run(self.runner.run())
+        self.raise_run_status()
+
+    def raise_run_status(self):
+        if self.overruns > 0 or self.assertion_failures > 0:
+            raise AssertionViolation()
diff --git a/silverback/runner.py b/silverback/runner.py
index 46987fa4..b3a69a90 100644
--- a/silverback/runner.py
+++ b/silverback/runner.py
@@ -5,6 +5,7 @@
 from ape.logging import logger
 from ape.utils import ManagerAccessMixin
 from ape_ethereum.ecosystem import keccak
+from click import progressbar
 from ethpm_types import EventABI
 from packaging.specifiers import SpecifierSet
 from packaging.version import Version
@@ -420,3 +421,66 @@ async def _event_task(self, task_data: TaskData):
             await self._checkpoint(last_block_seen=event.block_number)
             await self._handle_task(await event_log_task_kicker.kiq(event))
             await self._checkpoint(last_block_processed=event.block_number)
+
+
+class BacktestRunner(BaseRunner):
+    def __init__(
+        self,
+        app: SilverbackBot,
+        start_block: int,
+        stop_block: int,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(app, *args, **kwargs)
+
+        # NOTE: Takes time to do the data collection
+        with progressbar(
+            chain.blocks.range(start_block, stop_block + 1),
+            length=(stop_block - start_block),
+        ) as blocks:
+            self.blocks = list(blocks)
+
+        logger.info(
+            f"Using {self.__class__.__name__}:"
+            f" num_blocks={stop_block - start_block}"
+            f" max_exceptions={self.max_exceptions}"
+        )
+
+    async def _block_task(self, task_data: TaskData):
+        new_block_task_kicker = self._create_task_kicker(task_data)
+
+        async for block in async_wrap_iter(iter(self.blocks)):
+            await self._checkpoint(last_block_seen=block.number)
+            await self._handle_task(await new_block_task_kicker.kiq(block))
+            await self._checkpoint(last_block_processed=block.number)
+
+    async def _event_task(self, task_data: TaskData):
+        if not (event_signature := task_data.labels.get("event_signature")):
+            raise StartupFailure("No Event Signature provided.")
+
+        event_abi = EventABI.from_signature(event_signature)
+
+        if not (contract_address := task_data.labels.get("contract_address")):
+            raise StartupFailure("Contract instance required.")
+
+        if (
+            not (
+                events := chain.contracts.instance_at(contract_address)._events_.get(event_abi.name)
+            )
+            or len(events) == 0
+        ):
+            raise StartupFailure(
+                "Contract '{contract_address}' does not have event '{event_abi.name}'."
+            )
+
+        event_log_task_kicker = self._create_task_kicker(task_data)
+
+        async for block in async_wrap_iter(iter(self.blocks)):
+            txn_hashes = iter(tx.txn_hash.hex() for tx in block.transactions)
+            receipts = map(chain.get_receipt, txn_hashes)
+            async for logs in async_wrap_iter(map(events[0].from_receipt, receipts)):
+                async for log in async_wrap_iter(iter(logs)):
+                    await self._checkpoint(last_block_seen=log.block_number)
+                    await self._handle_task(await event_log_task_kicker.kiq(log))
+                    await self._checkpoint(last_block_processed=log.block_number)
diff --git a/tests/backtest_merge.yaml b/tests/backtest_merge.yaml
new file mode 100644
index 00000000..67f16b7b
--- /dev/null
+++ b/tests/backtest_merge.yaml
@@ -0,0 +1,5 @@
+bots: ["example"]
+network: "ethereum:mainnet"
+start_block: 15_338_009
+stop_block: 15_338_018
+something_else: blah