diff --git a/src/ape/managers/_contractscache.py b/src/ape/managers/_contractscache.py index fe5677bc3e..250775ec13 100644 --- a/src/ape/managers/_contractscache.py +++ b/src/ape/managers/_contractscache.py @@ -270,13 +270,22 @@ def _delete_proxy(self, address: AddressType): def __contains__(self, address: AddressType) -> bool: return self.get(address) is not None - def cache_deployment(self, contract_instance: ContractInstance): + def cache_deployment( + self, + contract_instance: ContractInstance, + proxy_info: Optional[ProxyInfoAPI] = None, + detect_proxy: bool = True, + ): """ Cache the given contract instance's type and deployment information. Args: contract_instance (:class:`~ape.contracts.base.ContractInstance`): The contract to cache. + proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, to + avoid the potentially expensive look-up. + detect_proxy (bool): Set to ``False`` to avoid detecting if the contract is a + proxy. """ address = contract_instance.address contract_type = contract_instance.contract_type # may be a proxy @@ -285,24 +294,22 @@ def cache_deployment(self, contract_instance: ContractInstance): # in case it is needed somewhere. It may get overridden. self.contract_types.memory[address] = contract_type - if proxy_info := self.provider.network.ecosystem.get_proxy_info(address): - # The user is caching a deployment of a proxy with the target already set. - self.cache_proxy_info(address, proxy_info) - if implementation_contract := self.get(proxy_info.target): - updated_proxy_contract = _get_combined_contract_type( - contract_type, proxy_info, implementation_contract - ) - self.contract_types[address] = updated_proxy_contract + if proxy_info: + # Was given proxy info. + self._cache_proxy_contract(address, proxy_info, contract_type, contract_instance) - # Use this contract type in the user's contract instance. - contract_instance.contract_type = updated_proxy_contract + elif detect_proxy: + # Proxy info was not provided. Use the connected ecosystem to figure it out. + if proxy_info := self.provider.network.ecosystem.get_proxy_info(address): + # The user is caching a deployment of a proxy with the target already set. + self._cache_proxy_contract(address, proxy_info, contract_type, contract_instance) else: - # No implementation yet. Just cache proxy. + # Cache as normal. self.contract_types[address] = contract_type else: - # Regular contract. Cache normally. + # Cache as normal; do not do expensive proxy detection. self.contract_types[address] = contract_type # Cache the deployment now. @@ -312,6 +319,26 @@ def cache_deployment(self, contract_instance: ContractInstance): return contract_type + def _cache_proxy_contract( + self, + address: AddressType, + proxy_info: ProxyInfoAPI, + contract_type: ContractType, + contract_instance: ContractInstance, + ): + self.cache_proxy_info(address, proxy_info) + if implementation_contract := self.get(proxy_info.target): + updated_proxy_contract = _get_combined_contract_type( + contract_type, proxy_info, implementation_contract + ) + self.contract_types[address] = updated_proxy_contract + + # Use this contract type in the user's contract instance. + contract_instance.contract_type = updated_proxy_contract + else: + # No implementation yet. Just cache proxy. + self.contract_types[address] = contract_type + def cache_proxy_info(self, address: AddressType, proxy_info: ProxyInfoAPI): """ Cache proxy info for a particular address, useful for plugins adding already @@ -492,6 +519,8 @@ def get( address: AddressType, default: Optional[ContractType] = None, fetch_from_explorer: bool = True, + proxy_info: Optional[ProxyInfoAPI] = None, + detect_proxy: bool = True, ) -> Optional[ContractType]: """ Get a contract type by address. @@ -506,6 +535,9 @@ def get( fetch_from_explorer (bool): Set to ``False`` to avoid fetching from an explorer. Defaults to ``True``. Only fetches if it needs to (uses disk & memory caching otherwise). + proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, + to avoid the potentially expensive look-up. + detect_proxy (bool): Set to ``False`` to avoid detecting if it is a proxy. Returns: Optional[ContractType]: The contract type if it was able to get one, @@ -531,13 +563,14 @@ def get( else: # Contract is not cached yet. Check broader sources, such as an explorer. - # First, detect if this is a proxy. - if not (proxy_info := self.proxy_infos[address_key]): - if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key): - self.proxy_infos[address_key] = proxy_info + if not proxy_info and detect_proxy: + # Proxy info not provided. Attempt to detect. + if not (proxy_info := self.proxy_infos[address_key]): + if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key): + self.proxy_infos[address_key] = proxy_info if proxy_info: - # Contract is a proxy. + # Contract is a proxy (either was detected or provided). implementation_contract_type = self.get(proxy_info.target, default=default) proxy_contract_type = ( self._get_contract_type_from_explorer(address_key) @@ -594,6 +627,8 @@ def instance_at( txn_hash: Optional[Union[str, "HexBytes"]] = None, abi: Optional[Union[list[ABI], dict, str, Path]] = None, fetch_from_explorer: bool = True, + proxy_info: Optional[ProxyInfoAPI] = None, + detect_proxy: bool = True, ) -> ContractInstance: """ Get a contract at the given address. If the contract type of the contract is known, @@ -618,6 +653,9 @@ def instance_at( fetch_from_explorer (bool): Set to ``False`` to avoid fetching from the explorer. Defaults to ``True``. Won't fetch unless it needs to (uses disk & memory caching first). + proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, to avoid + the potentially expensive look-up. + detect_proxy (bool): Set to ``False`` to avoid detecting if the contract is a proxy. Returns: :class:`~ape.contracts.base.ContractInstance` @@ -640,7 +678,11 @@ def instance_at( try: # Always attempt to get an existing contract type to update caches contract_type = self.get( - contract_address, default=contract_type, fetch_from_explorer=fetch_from_explorer + contract_address, + default=contract_type, + fetch_from_explorer=fetch_from_explorer, + proxy_info=proxy_info, + detect_proxy=detect_proxy, ) except Exception as err: if contract_type or abi: diff --git a/tests/functional/test_contracts_cache.py b/tests/functional/test_contracts_cache.py index 3b1c01a4a7..0342080322 100644 --- a/tests/functional/test_contracts_cache.py +++ b/tests/functional/test_contracts_cache.py @@ -120,6 +120,45 @@ def test_instance_at_use_abi(chain, solidity_fallback_contract, owner): assert instance2.contract_type.abi == instance.contract_type.abi +def test_instance_at_provide_proxy(mocker, chain, vyper_contract_instance, owner): + address = vyper_contract_instance.address + container = _make_minimal_proxy(address=address.lower()) + proxy = container.deploy(sender=owner) + proxy_info = chain.contracts.proxy_infos[proxy.address] + + del chain.contracts[proxy.address] + + proxy_detection_spy = mocker.spy(chain.contracts.proxy_infos, "get_type") + + with pytest.raises(ContractNotFoundError): + # This just fails because we deleted it from the cache so Ape no + # longer knows what the contract type is. That is fine for this test! + chain.contracts.instance_at(proxy.address, proxy_info=proxy_info) + + # The real test: we check the spy to ensure we never attempted to look up + # the proxy info for the given address to `instance_at()`. + for call in proxy_detection_spy.call_args_list: + for arg in call[0]: + assert proxy.address != arg + + +def test_instance_at_skip_proxy(mocker, chain, vyper_contract_instance, owner): + address = vyper_contract_instance.address + del chain.contracts[address] + proxy_detection_spy = mocker.spy(chain.contracts.proxy_infos, "get_type") + + with pytest.raises(ContractNotFoundError): + # This just fails because we deleted it from the cache so Ape no + # longer knows what the contract type is. That is fine for this test! + chain.contracts.instance_at(address, detect_proxy=False) + + # The real test: we check the spy to ensure we never attempted to look up + # the proxy info for the given address to `instance_at()`. + for call in proxy_detection_spy.call_args_list: + for arg in call[0]: + assert address != arg + + def test_cache_deployment_live_network( chain, vyper_contract_instance,