Skip to content

Commit

Permalink
feat: ability to skip proxy detection (#2470)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jan 27, 2025
1 parent 7bd2810 commit 601e2c4
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 19 deletions.
80 changes: 61 additions & 19 deletions src/ape/managers/_contractscache.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,22 @@ def _delete_proxy(self, address: AddressType):
def __contains__(self, address: AddressType) -> bool:
return self.get(address) is not None

def cache_deployment(self, contract_instance: ContractInstance):
def cache_deployment(
self,
contract_instance: ContractInstance,
proxy_info: Optional[ProxyInfoAPI] = None,
detect_proxy: bool = True,
):
"""
Cache the given contract instance's type and deployment information.
Args:
contract_instance (:class:`~ape.contracts.base.ContractInstance`): The contract
to cache.
proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, to
avoid the potentially expensive look-up.
detect_proxy (bool): Set to ``False`` to avoid detecting if the contract is a
proxy.
"""
address = contract_instance.address
contract_type = contract_instance.contract_type # may be a proxy
Expand All @@ -285,24 +294,22 @@ def cache_deployment(self, contract_instance: ContractInstance):
# in case it is needed somewhere. It may get overridden.
self.contract_types.memory[address] = contract_type

if proxy_info := self.provider.network.ecosystem.get_proxy_info(address):
# The user is caching a deployment of a proxy with the target already set.
self.cache_proxy_info(address, proxy_info)
if implementation_contract := self.get(proxy_info.target):
updated_proxy_contract = _get_combined_contract_type(
contract_type, proxy_info, implementation_contract
)
self.contract_types[address] = updated_proxy_contract
if proxy_info:
# Was given proxy info.
self._cache_proxy_contract(address, proxy_info, contract_type, contract_instance)

# Use this contract type in the user's contract instance.
contract_instance.contract_type = updated_proxy_contract
elif detect_proxy:
# Proxy info was not provided. Use the connected ecosystem to figure it out.
if proxy_info := self.provider.network.ecosystem.get_proxy_info(address):
# The user is caching a deployment of a proxy with the target already set.
self._cache_proxy_contract(address, proxy_info, contract_type, contract_instance)

else:
# No implementation yet. Just cache proxy.
# Cache as normal.
self.contract_types[address] = contract_type

else:
# Regular contract. Cache normally.
# Cache as normal; do not do expensive proxy detection.
self.contract_types[address] = contract_type

# Cache the deployment now.
Expand All @@ -312,6 +319,26 @@ def cache_deployment(self, contract_instance: ContractInstance):

return contract_type

def _cache_proxy_contract(
self,
address: AddressType,
proxy_info: ProxyInfoAPI,
contract_type: ContractType,
contract_instance: ContractInstance,
):
self.cache_proxy_info(address, proxy_info)
if implementation_contract := self.get(proxy_info.target):
updated_proxy_contract = _get_combined_contract_type(
contract_type, proxy_info, implementation_contract
)
self.contract_types[address] = updated_proxy_contract

# Use this contract type in the user's contract instance.
contract_instance.contract_type = updated_proxy_contract
else:
# No implementation yet. Just cache proxy.
self.contract_types[address] = contract_type

def cache_proxy_info(self, address: AddressType, proxy_info: ProxyInfoAPI):
"""
Cache proxy info for a particular address, useful for plugins adding already
Expand Down Expand Up @@ -492,6 +519,8 @@ def get(
address: AddressType,
default: Optional[ContractType] = None,
fetch_from_explorer: bool = True,
proxy_info: Optional[ProxyInfoAPI] = None,
detect_proxy: bool = True,
) -> Optional[ContractType]:
"""
Get a contract type by address.
Expand All @@ -506,6 +535,9 @@ def get(
fetch_from_explorer (bool): Set to ``False`` to avoid fetching from an
explorer. Defaults to ``True``. Only fetches if it needs to (uses disk
& memory caching otherwise).
proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known,
to avoid the potentially expensive look-up.
detect_proxy (bool): Set to ``False`` to avoid detecting if it is a proxy.
Returns:
Optional[ContractType]: The contract type if it was able to get one,
Expand All @@ -531,13 +563,14 @@ def get(

else:
# Contract is not cached yet. Check broader sources, such as an explorer.
# First, detect if this is a proxy.
if not (proxy_info := self.proxy_infos[address_key]):
if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key):
self.proxy_infos[address_key] = proxy_info
if not proxy_info and detect_proxy:
# Proxy info not provided. Attempt to detect.
if not (proxy_info := self.proxy_infos[address_key]):
if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key):
self.proxy_infos[address_key] = proxy_info

if proxy_info:
# Contract is a proxy.
# Contract is a proxy (either was detected or provided).
implementation_contract_type = self.get(proxy_info.target, default=default)
proxy_contract_type = (
self._get_contract_type_from_explorer(address_key)
Expand Down Expand Up @@ -594,6 +627,8 @@ def instance_at(
txn_hash: Optional[Union[str, "HexBytes"]] = None,
abi: Optional[Union[list[ABI], dict, str, Path]] = None,
fetch_from_explorer: bool = True,
proxy_info: Optional[ProxyInfoAPI] = None,
detect_proxy: bool = True,
) -> ContractInstance:
"""
Get a contract at the given address. If the contract type of the contract is known,
Expand All @@ -618,6 +653,9 @@ def instance_at(
fetch_from_explorer (bool): Set to ``False`` to avoid fetching from the explorer.
Defaults to ``True``. Won't fetch unless it needs to (uses disk & memory caching
first).
proxy_info (Optional[ProxyInfoAPI]): Pass in the proxy info, if it is known, to avoid
the potentially expensive look-up.
detect_proxy (bool): Set to ``False`` to avoid detecting if the contract is a proxy.
Returns:
:class:`~ape.contracts.base.ContractInstance`
Expand All @@ -640,7 +678,11 @@ def instance_at(
try:
# Always attempt to get an existing contract type to update caches
contract_type = self.get(
contract_address, default=contract_type, fetch_from_explorer=fetch_from_explorer
contract_address,
default=contract_type,
fetch_from_explorer=fetch_from_explorer,
proxy_info=proxy_info,
detect_proxy=detect_proxy,
)
except Exception as err:
if contract_type or abi:
Expand Down
39 changes: 39 additions & 0 deletions tests/functional/test_contracts_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,45 @@ def test_instance_at_use_abi(chain, solidity_fallback_contract, owner):
assert instance2.contract_type.abi == instance.contract_type.abi


def test_instance_at_provide_proxy(mocker, chain, vyper_contract_instance, owner):
address = vyper_contract_instance.address
container = _make_minimal_proxy(address=address.lower())
proxy = container.deploy(sender=owner)
proxy_info = chain.contracts.proxy_infos[proxy.address]

del chain.contracts[proxy.address]

proxy_detection_spy = mocker.spy(chain.contracts.proxy_infos, "get_type")

with pytest.raises(ContractNotFoundError):
# This just fails because we deleted it from the cache so Ape no
# longer knows what the contract type is. That is fine for this test!
chain.contracts.instance_at(proxy.address, proxy_info=proxy_info)

# The real test: we check the spy to ensure we never attempted to look up
# the proxy info for the given address to `instance_at()`.
for call in proxy_detection_spy.call_args_list:
for arg in call[0]:
assert proxy.address != arg


def test_instance_at_skip_proxy(mocker, chain, vyper_contract_instance, owner):
address = vyper_contract_instance.address
del chain.contracts[address]
proxy_detection_spy = mocker.spy(chain.contracts.proxy_infos, "get_type")

with pytest.raises(ContractNotFoundError):
# This just fails because we deleted it from the cache so Ape no
# longer knows what the contract type is. That is fine for this test!
chain.contracts.instance_at(address, detect_proxy=False)

# The real test: we check the spy to ensure we never attempted to look up
# the proxy info for the given address to `instance_at()`.
for call in proxy_detection_spy.call_args_list:
for arg in call[0]:
assert address != arg


def test_cache_deployment_live_network(
chain,
vyper_contract_instance,
Expand Down

0 comments on commit 601e2c4

Please sign in to comment.