From 6512afc7857afe1033a4cbedf40977660304ba12 Mon Sep 17 00:00:00 2001
From: Juliya Smith <jules@apeworx.io>
Date: Fri, 6 Dec 2024 13:21:50 -0600
Subject: [PATCH] test: fix tests

---
 src/ape/pytest/fixtures.py        |  19 +++-
 src/ape_test/provider.py          |  21 +++-
 tests/functional/test_fixtures.py | 156 ++++++++++++++++++++----------
 3 files changed, 138 insertions(+), 58 deletions(-)

diff --git a/src/ape/pytest/fixtures.py b/src/ape/pytest/fixtures.py
index bfe369e61a..56c67dd464 100644
--- a/src/ape/pytest/fixtures.py
+++ b/src/ape/pytest/fixtures.py
@@ -549,10 +549,16 @@ class IsolationManager(ManagerAccessMixin):
     supported: bool = True
     snapshots: SnapshotRegistry = SnapshotRegistry()
 
-    def __init__(self, config_wrapper: "ConfigWrapper", receipt_capture: "ReceiptCapture"):
+    def __init__(
+        self,
+        config_wrapper: "ConfigWrapper",
+        receipt_capture: "ReceiptCapture",
+        chain_snapshots: Optional[dict] = None,
+    ):
         self.config_wrapper = config_wrapper
         self.receipt_capture = receipt_capture
         self._records: list[str] = []
+        self._chain_snapshots = chain_snapshots
 
     @cached_property
     def _track_transactions(self) -> bool:
@@ -562,6 +568,10 @@ def _track_transactions(self) -> bool:
             and (self.config_wrapper.track_gas or self.config_wrapper.track_coverage)
         )
 
+    @property
+    def chain_snapshots(self) -> dict:
+        return self._chain_snapshots or self.chain_manager._snapshots
+
     def get_snapshot(self, scope: Scope) -> Snapshot:
         return self.snapshots[scope]
 
@@ -634,7 +644,7 @@ def restore(self, scope: Scope):
         if snapshot_id is None:
             return
 
-        elif snapshot_id not in self.chain_manager._snapshots[self.provider.chain_id]:
+        elif snapshot_id not in self.chain_snapshots[self.provider.chain_id]:
             # Still clear out.
             self.snapshots.clear_snapshot_id(scope)
             return
@@ -643,7 +653,7 @@ def restore(self, scope: Scope):
             self._records.append(f"restoring '{scope.name.upper()}'")
 
         try:
-            self.chain_manager.restore(snapshot_id)
+            self._restore(snapshot_id)
         except NotImplementedError:
             logger.warning(
                 "The connected provider does not support snapshotting. "
@@ -654,6 +664,9 @@ def restore(self, scope: Scope):
 
         self.snapshots.clear_snapshot_id(scope)
 
+    def _restore(self, snapshot_id: "SnapshotID"):
+        self.chain_manager.restore(snapshot_id)
+
     def show_records(self):
         if not self._records:
             return
diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py
index ca3fee0b32..26aec6f169 100644
--- a/src/ape_test/provider.py
+++ b/src/ape_test/provider.py
@@ -107,19 +107,34 @@ def __init__(
     ):
         self.config = config
         self.chain_id = chain_id
+        self._backend = backend
         self._ethereum_tester = None if backend is None else EthereumTester(backend)
 
     @property
     def ethereum_tester(self) -> EthereumTester:
         if self._ethereum_tester is None:
-            backend = ApeEVMBackend(self.config)
-            self._ethereum_tester = EthereumTester(backend)
+            self._backend = ApeEVMBackend(self.config)
+            self._ethereum_tester = EthereumTester(self._backend)
 
         return self._ethereum_tester
 
     @ethereum_tester.setter
     def ethereum_tester(self, value):
         self._ethereum_tester = value
+        self._backend = value.backend
+
+    @property
+    def backend(self) -> "ApeEVMBackend":
+        if self._backend is None:
+            self._backend = ApeEVMBackend(self.config)
+            self._ethereum_tester = EthereumTester(self._backend)
+
+        return self._backend
+
+    @backend.setter
+    def backend(self, value):
+        self._backend = value
+        self._ethereum_tester = EthereumTester(self._backend)
 
     @cached_property
     def api_endpoints(self) -> dict:  # type: ignore
@@ -143,7 +158,7 @@ def config(self) -> "ApeTestConfig":  # type: ignore
 
     @property
     def evm_backend(self) -> ApeEVMBackend:
-        return self.tester.ethereum_tester.backend
+        return self.tester.backend
 
     @cached_property
     def tester(self) -> ApeTester:
diff --git a/tests/functional/test_fixtures.py b/tests/functional/test_fixtures.py
index 3ab36bdf86..04bfeb4aa7 100644
--- a/tests/functional/test_fixtures.py
+++ b/tests/functional/test_fixtures.py
@@ -1,8 +1,11 @@
+from typing import Optional
+
 import pytest
 
 from ape.exceptions import BlockNotFoundError
 from ape.pytest.fixtures import IsolationManager, PytestApeFixtures
 from ape.pytest.utils import Scope
+from ape.types.vm import SnapshotID
 
 
 @pytest.fixture
@@ -35,24 +38,6 @@ def mock_evm(mocker):
     return mocker.MagicMock()
 
 
-@pytest.fixture
-def use_mock_provider(networks, mock_provider, mock_evm):
-    orig_provider = networks.active_provider
-    mock_provider._web3.eth.get_block.side_effect = orig_provider._web3.eth.get_block
-    networks.active_provider = mock_provider
-    orig_backend = mock_provider._evm_backend
-
-    # Ensure functional isolation still uses snapshot.
-    mock_evm.take_snapshot.side_effect = orig_backend.take_snapshot
-
-    try:
-        mock_provider._evm_backend = mock_evm
-        yield mock_provider
-    finally:
-        mock_provider._evm_backend = orig_backend
-        networks.active_provider = orig_provider
-
-
 def test_isolation(isolation, receipt_capture):
     # Set up receipt capture to fail on __exit__
     # AFTER the yield statement. There was a bug
@@ -97,65 +82,132 @@ def test_isolation_restore_not_implemented(mocker, networks, fixtures):
 
 
 @pytest.mark.parametrize("snapshot_id", (0, 1, "123"))
-def test_isolation_snapshot_id_types(snapshot_id, use_mock_provider, fixtures, mock_evm):
-    mock_evm.take_snapshot.side_effect = lambda: snapshot_id
-    isolation_context = fixtures.isolation_manager.isolation(Scope.FUNCTION)
+def test_isolation_snapshot_id_types(snapshot_id, fixtures):
+    class IsolationManagerWithCustomSnapshot(IsolationManager):
+        take_call_count = 0
+        restore_call_count = 0
+        restore_called_with = []
+
+        def take_snapshot(self) -> Optional[SnapshotID]:
+            self.take_call_count += 1
+            return snapshot_id
+
+        def restore(self, scope: Scope):
+            self.restore_call_count += 1
+            self.restore_called_with.append(scope)
+
+    isolation_manager = IsolationManagerWithCustomSnapshot(
+        fixtures.isolation_manager.config_wrapper,
+        fixtures.isolation_manager.receipt_capture,
+    )
+    isolation_context = isolation_manager.isolation(Scope.FUNCTION)
     next(isolation_context)  # Enter.
-    assert mock_evm.take_snapshot.call_count == 1
-    assert mock_evm.revert_to_snapshot.call_count == 0
+    assert isolation_manager.take_call_count == 1
+    assert isolation_manager.restore_call_count == 0
     next(isolation_context, None)  # Exit.
-    mock_evm.revert_to_snapshot.assert_called_once_with(snapshot_id)
+    assert isolation_manager.restore_called_with == [Scope.FUNCTION]
+
+
+def test_isolation_when_snapshot_fails_avoids_restore(fixtures):
+    class IsolationManagerFailingAtSnapshotting(IsolationManager):
+        take_called = False
+        restore_called = False
 
+        def take_snapshot(self) -> Optional["SnapshotID"]:
+            self.take_called = True
+            raise NotImplementedError()
 
-def test_isolation_when_snapshot_fails_avoids_restore(use_mock_provider, fixtures, mock_evm):
-    mock_evm.take_snapshot.side_effect = NotImplementedError
-    isolation_context = fixtures.isolation_manager.isolation(Scope.FUNCTION)
+        def restore(self, scope: Scope):
+            self.restore_called = True
+
+    isolation_manager = IsolationManagerFailingAtSnapshotting(
+        fixtures.isolation_manager.config_wrapper,
+        fixtures.isolation_manager.receipt_capture,
+    )
+    isolation_context = isolation_manager.isolation(Scope.FUNCTION)
     next(isolation_context)  # Enter.
-    assert mock_evm.take_snapshot.call_count == 1
-    assert mock_evm.revert_to_snapshot.call_count == 0
+    assert isolation_manager.take_called
+    assert not isolation_manager.restore_called
     next(isolation_context, None)  # Exit.
     # It doesn't even try!
-    assert mock_evm.revert_to_snapshot.call_count == 0
+    assert not isolation_manager.restore_called
+
+
+def test_isolation_restore_fails_avoids_snapshot_next_time(fixtures):
+    chain_snapshots = {}
+
+    class IsolationManagerFailingAtRestoring(IsolationManager):
+        take_called = False
+        restore_called = False
 
+        def take_snapshot(self) -> Optional["SnapshotID"]:
+            self.take_called = True
+            chain_snapshots[self.provider.chain_id] = ["123"]
+            return "123"
 
-def test_isolation_restore_fails_avoids_snapshot_next_time(
-    networks, use_mock_provider, fixtures, mock_evm
-):
-    mock_evm.take_snapshot.return_value = 123
-    mock_evm.revert_to_snapshot.side_effect = NotImplementedError
-    isolation_context = fixtures.isolation_manager.isolation(Scope.FUNCTION)
+        def _restore(self, snapshot_id: SnapshotID):
+            self.restore_called = True
+            raise NotImplementedError()
+
+        def reset_mock(self):
+            self.take_called = False
+            self.restore_called = False
+
+    isolation_manager = IsolationManagerFailingAtRestoring(
+        fixtures.isolation_manager.config_wrapper,
+        fixtures.isolation_manager.receipt_capture,
+        chain_snapshots=chain_snapshots,
+    )
+    isolation_context = isolation_manager.isolation(Scope.FUNCTION)
     next(isolation_context)  # Enter.
     # Snapshot works, we get this far.
-    assert mock_evm.take_snapshot.call_count == 1
-    assert mock_evm.revert_to_snapshot.call_count == 0
+    assert isolation_manager.take_called
+    assert not isolation_manager.restore_called
 
-    # At this point, it is realized snapshotting is no-go.
-    mock_evm.take_snapshot.reset_mock()
+    # At this point, it realized snapshotting is no-go.
     next(isolation_context, None)  # Exit.
-    isolation_context = fixtures.isolation_manager.isolation(Scope.FUNCTION)
+    assert isolation_manager.restore_called
+
+    isolation_manager.reset_mock()
+    isolation_context = isolation_manager.isolation(Scope.FUNCTION)
     next(isolation_context)  # Enter again.
+
     # This time, snapshotting is NOT attempted.
-    assert mock_evm.take_snapshot.call_count == 0
+    assert not isolation_manager.take_called
+    assert not isolation_manager.restore_called
 
 
-def test_isolation_supported_flag_set_after_successful_snapshot(
-    use_mock_provider, fixtures, mock_evm
-):
+def test_isolation_supported_flag_set_after_successful_snapshot(fixtures):
     """
     Testing the unusual case where `.supported` was changed manually after
     a successful snapshot and before the restore attempt.
     """
-    mock_evm.take_snapshot.return_value = 123
-    isolation_context = fixtures.isolation_manager.isolation(Scope.FUNCTION)
+
+    class CustomIsolationManager(IsolationManager):
+        take_called = False
+        restore_called = False
+
+        def take_snapshot(self) -> Optional["SnapshotID"]:
+            self.take_called = True
+            return 123
+
+        def restore(self, scope: Scope):
+            self.restore_called = True
+
+    isolation_manager = CustomIsolationManager(
+        fixtures.isolation_manager.config_wrapper,
+        fixtures.isolation_manager.receipt_capture,
+    )
+    isolation_context = isolation_manager.isolation(Scope.FUNCTION)
     next(isolation_context)  # Enter.
-    assert mock_evm.take_snapshot.call_count == 1
-    assert mock_evm.revert_to_snapshot.call_count == 0
+    assert isolation_manager.take_called
+    assert not isolation_manager.restore_called
 
     # HACK: Change the flag manually to show it will avoid
     #   the restore.
-    fixtures.isolation_manager.supported = False
-
+    isolation_manager.supported = False
+    isolation_manager.take_called = False  # Reset
     next(isolation_context, None)  # Exit.
     # Even though snapshotting worked, the flag was changed,
     # and so the restore never gets attempted.
-    assert mock_evm.revert_to_snapshot.call_count == 0
+    assert not isolation_manager.take_called