From 601e2c461b01b50330dd0ee5497b5c63974403c9 Mon Sep 17 00:00:00 2001
From: antazoey <jules@apeworx.io>
Date: Mon, 27 Jan 2025 16:22:12 -0600
Subject: [PATCH] feat: ability to skip proxy detection (#2470)

---
 src/ape/managers/_contractscache.py      | 80 ++++++++++++++++++------
 tests/functional/test_contracts_cache.py | 39 ++++++++++++
 2 files changed, 100 insertions(+), 19 deletions(-)

diff --git a/src/ape/managers/_contractscache.py b/src/ape/managers/_contractscache.py
index fe5677bc3e..250775ec13 100644
--- a/src/ape/managers/_contractscache.py
+++ b/src/ape/managers/_contractscache.py
@@ -270,13 +270,22 @@ def _delete_proxy(self, address: AddressType):
     def __contains__(self, address: AddressType) -> bool:
         return self.get(address) is not None
 
-    def cache_deployment(self, contract_instance: ContractInstance):
+    def cache_deployment(
+        self,
+        contract_instance: ContractInstance,
+        proxy_info: Optional[ProxyInfoAPI] = None,
+        detect_proxy: bool = True,
+    ):
         """
         Cache the given contract instance's type and deployment information.
 
         Args:
             contract_instance (:class:`~ape.contracts.base.ContractInstance`): The contract
               to cache.
+            proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, to
+              avoid the potentially expensive look-up.
+            detect_proxy (bool): Set to ``False`` to avoid detecting if the contract is a
+              proxy.
         """
         address = contract_instance.address
         contract_type = contract_instance.contract_type  # may be a proxy
@@ -285,24 +294,22 @@ def cache_deployment(self, contract_instance: ContractInstance):
         # in case it is needed somewhere. It may get overridden.
         self.contract_types.memory[address] = contract_type
 
-        if proxy_info := self.provider.network.ecosystem.get_proxy_info(address):
-            # The user is caching a deployment of a proxy with the target already set.
-            self.cache_proxy_info(address, proxy_info)
-            if implementation_contract := self.get(proxy_info.target):
-                updated_proxy_contract = _get_combined_contract_type(
-                    contract_type, proxy_info, implementation_contract
-                )
-                self.contract_types[address] = updated_proxy_contract
+        if proxy_info:
+            # Was given proxy info.
+            self._cache_proxy_contract(address, proxy_info, contract_type, contract_instance)
 
-                # Use this contract type in the user's contract instance.
-                contract_instance.contract_type = updated_proxy_contract
+        elif detect_proxy:
+            # Proxy info was not provided. Use the connected ecosystem to figure it out.
+            if proxy_info := self.provider.network.ecosystem.get_proxy_info(address):
+                # The user is caching a deployment of a proxy with the target already set.
+                self._cache_proxy_contract(address, proxy_info, contract_type, contract_instance)
 
             else:
-                # No implementation yet. Just cache proxy.
+                # Cache as normal.
                 self.contract_types[address] = contract_type
 
         else:
-            # Regular contract. Cache normally.
+            # Cache as normal; do not do expensive proxy detection.
             self.contract_types[address] = contract_type
 
         # Cache the deployment now.
@@ -312,6 +319,26 @@ def cache_deployment(self, contract_instance: ContractInstance):
 
         return contract_type
 
+    def _cache_proxy_contract(
+        self,
+        address: AddressType,
+        proxy_info: ProxyInfoAPI,
+        contract_type: ContractType,
+        contract_instance: ContractInstance,
+    ):
+        self.cache_proxy_info(address, proxy_info)
+        if implementation_contract := self.get(proxy_info.target):
+            updated_proxy_contract = _get_combined_contract_type(
+                contract_type, proxy_info, implementation_contract
+            )
+            self.contract_types[address] = updated_proxy_contract
+
+            # Use this contract type in the user's contract instance.
+            contract_instance.contract_type = updated_proxy_contract
+        else:
+            # No implementation yet. Just cache proxy.
+            self.contract_types[address] = contract_type
+
     def cache_proxy_info(self, address: AddressType, proxy_info: ProxyInfoAPI):
         """
         Cache proxy info for a particular address, useful for plugins adding already
@@ -492,6 +519,8 @@ def get(
         address: AddressType,
         default: Optional[ContractType] = None,
         fetch_from_explorer: bool = True,
+        proxy_info: Optional[ProxyInfoAPI] = None,
+        detect_proxy: bool = True,
     ) -> Optional[ContractType]:
         """
         Get a contract type by address.
@@ -506,6 +535,9 @@ def get(
             fetch_from_explorer (bool): Set to ``False`` to avoid fetching from an
               explorer. Defaults to ``True``. Only fetches if it needs to (uses disk
               & memory caching otherwise).
+            proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known,
+              to avoid the potentially expensive look-up.
+            detect_proxy (bool): Set to ``False`` to avoid detecting if it is a proxy.
 
         Returns:
             Optional[ContractType]: The contract type if it was able to get one,
@@ -531,13 +563,14 @@ def get(
 
         else:
             # Contract is not cached yet. Check broader sources, such as an explorer.
-            # First, detect if this is a proxy.
-            if not (proxy_info := self.proxy_infos[address_key]):
-                if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key):
-                    self.proxy_infos[address_key] = proxy_info
+            if not proxy_info and detect_proxy:
+                # Proxy info not provided. Attempt to detect.
+                if not (proxy_info := self.proxy_infos[address_key]):
+                    if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key):
+                        self.proxy_infos[address_key] = proxy_info
 
             if proxy_info:
-                # Contract is a proxy.
+                # Contract is a proxy (either was detected or provided).
                 implementation_contract_type = self.get(proxy_info.target, default=default)
                 proxy_contract_type = (
                     self._get_contract_type_from_explorer(address_key)
@@ -594,6 +627,8 @@ def instance_at(
         txn_hash: Optional[Union[str, "HexBytes"]] = None,
         abi: Optional[Union[list[ABI], dict, str, Path]] = None,
         fetch_from_explorer: bool = True,
+        proxy_info: Optional[ProxyInfoAPI] = None,
+        detect_proxy: bool = True,
     ) -> ContractInstance:
         """
         Get a contract at the given address. If the contract type of the contract is known,
@@ -618,6 +653,9 @@ def instance_at(
             fetch_from_explorer (bool): Set to ``False`` to avoid fetching from the explorer.
               Defaults to ``True``. Won't fetch unless it needs to (uses disk & memory caching
               first).
+            proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, to avoid
+              the potentially expensive look-up.
+            detect_proxy (bool): Set to ``False`` to avoid detecting if the contract is a proxy.
 
         Returns:
             :class:`~ape.contracts.base.ContractInstance`
@@ -640,7 +678,11 @@ def instance_at(
         try:
             # Always attempt to get an existing contract type to update caches
             contract_type = self.get(
-                contract_address, default=contract_type, fetch_from_explorer=fetch_from_explorer
+                contract_address,
+                default=contract_type,
+                fetch_from_explorer=fetch_from_explorer,
+                proxy_info=proxy_info,
+                detect_proxy=detect_proxy,
             )
         except Exception as err:
             if contract_type or abi:
diff --git a/tests/functional/test_contracts_cache.py b/tests/functional/test_contracts_cache.py
index 3b1c01a4a7..0342080322 100644
--- a/tests/functional/test_contracts_cache.py
+++ b/tests/functional/test_contracts_cache.py
@@ -120,6 +120,45 @@ def test_instance_at_use_abi(chain, solidity_fallback_contract, owner):
     assert instance2.contract_type.abi == instance.contract_type.abi
 
 
+def test_instance_at_provide_proxy(mocker, chain, vyper_contract_instance, owner):
+    address = vyper_contract_instance.address
+    container = _make_minimal_proxy(address=address.lower())
+    proxy = container.deploy(sender=owner)
+    proxy_info = chain.contracts.proxy_infos[proxy.address]
+
+    del chain.contracts[proxy.address]
+
+    proxy_detection_spy = mocker.spy(chain.contracts.proxy_infos, "get_type")
+
+    with pytest.raises(ContractNotFoundError):
+        # This just fails because we deleted it from the cache so Ape no
+        # longer knows what the contract type is. That is fine for this test!
+        chain.contracts.instance_at(proxy.address, proxy_info=proxy_info)
+
+    # The real test: we check the spy to ensure we never attempted to look up
+    # the proxy info for the given address to `instance_at()`.
+    for call in proxy_detection_spy.call_args_list:
+        for arg in call[0]:
+            assert proxy.address != arg
+
+
+def test_instance_at_skip_proxy(mocker, chain, vyper_contract_instance, owner):
+    address = vyper_contract_instance.address
+    del chain.contracts[address]
+    proxy_detection_spy = mocker.spy(chain.contracts.proxy_infos, "get_type")
+
+    with pytest.raises(ContractNotFoundError):
+        # This just fails because we deleted it from the cache so Ape no
+        # longer knows what the contract type is. That is fine for this test!
+        chain.contracts.instance_at(address, detect_proxy=False)
+
+    # The real test: we check the spy to ensure we never attempted to look up
+    # the proxy info for the given address to `instance_at()`.
+    for call in proxy_detection_spy.call_args_list:
+        for arg in call[0]:
+            assert address != arg
+
+
 def test_cache_deployment_live_network(
     chain,
     vyper_contract_instance,