From fa2ba27fb48af1f38f30abdd749a0ed0081bbe58 Mon Sep 17 00:00:00 2001 From: antazoey Date: Thu, 22 Aug 2024 09:38:58 -0500 Subject: [PATCH] fix: issues with `CurrencyComparableValue` appearing `ContractLog.event_arguments` & on Pydantic models (#2221) --- src/ape/types/__init__.py | 48 +++++++++++++++- tests/functional/geth/test_provider.py | 3 + tests/functional/test_contract_event.py | 21 ++++++- tests/functional/test_types.py | 76 +++++++++++++++++++++---- 4 files changed, 136 insertions(+), 12 deletions(-) diff --git a/src/ape/types/__init__.py b/src/ape/types/__init__.py index 489f5d2729..066475e96f 100644 --- a/src/ape/types/__init__.py +++ b/src/ape/types/__init__.py @@ -19,7 +19,14 @@ ) from ethpm_types.abi import EventABI from ethpm_types.source import Closure -from pydantic import BaseModel, BeforeValidator, field_validator, model_validator +from pydantic import BaseModel, BeforeValidator, field_serializer, field_validator, model_validator +from pydantic_core.core_schema import ( + CoreSchema, + ValidationInfo, + int_schema, + no_info_plain_validator_function, + plain_serializer_function_ser_schema, +) from typing_extensions import TypeAlias from web3.types import FilterParams @@ -251,6 +258,15 @@ def __eq__(self, other: Any) -> bool: return True + @field_serializer("event_arguments") + def _serialize_event_arguments(self, event_arguments, info): + """ + Because of an issue with BigInt in Pydantic, + (https://github.com/pydantic/pydantic/issues/10152) + we have to ensure these are regular ints. + """ + return {k: int(v) if isinstance(v, int) else v for k, v in event_arguments.items()} + class ContractLog(ExtraAttributesMixin, BaseContractLog): """ @@ -484,6 +500,36 @@ def __eq__(self, other: Any) -> bool: # Try from the other end, if hasn't already. return NotImplemented + @classmethod + def __get_pydantic_core_schema__(cls, value, handler=None) -> CoreSchema: + return no_info_plain_validator_function( + cls._validate, + serialization=plain_serializer_function_ser_schema( + cls._serialize, + info_arg=False, + return_schema=int_schema(), + ), + ) + + @staticmethod + def _validate(value: Any, info: Optional[ValidationInfo] = None) -> "CurrencyValueComparable": + # NOTE: For some reason, for this to work, it has to happen + # in an "after" validator, or else it always only `int` type on the model. + if value is None: + # Will fail if not optional. + # Type ignore because this is an hacky and unlikely situation. + return None # type: ignore + + elif isinstance(value, str) and " " in value: + return ManagerAccessMixin.conversion_manager.convert(value, int) + + # For models annotating with this type, we validate all integers into it. + return CurrencyValueComparable(value) + + @staticmethod + def _serialize(value): + return int(value) + CurrencyValueComparable.__name__ = int.__name__ diff --git a/tests/functional/geth/test_provider.py b/tests/functional/geth/test_provider.py index be7028757f..dc5e76a97e 100644 --- a/tests/functional/geth/test_provider.py +++ b/tests/functional/geth/test_provider.py @@ -95,6 +95,7 @@ def test_uri_when_configured(geth_provider, project, ethereum): assert actual_mainnet_uri == expected +@geth_process_test def test_uri_non_dev_and_not_configured(mocker, ethereum): """ If the URI was not configured and we are not using a dev @@ -547,6 +548,7 @@ def test_make_request_not_exists(geth_provider): geth_provider.make_request("ape_thisDoesNotExist") +@geth_process_test def test_geth_bin_not_found(): bin_name = "__NOT_A_REAL_EXECUTABLE_HOPEFULLY__" with pytest.raises(NodeSoftwareNotInstalledError): @@ -677,6 +679,7 @@ def test_trace_approach_config(project): assert provider.call_trace_approach is TraceApproach.GETH_STRUCT_LOG_PARSE +@geth_process_test def test_start(mocker, convert, project, geth_provider): amount = convert("100_000 ETH", int) spy = mocker.spy(GethDevProcess, "from_uri") diff --git a/tests/functional/test_contract_event.py b/tests/functional/test_contract_event.py index 59291d19a8..509828184b 100644 --- a/tests/functional/test_contract_event.py +++ b/tests/functional/test_contract_event.py @@ -8,7 +8,7 @@ from ape.api import ReceiptAPI from ape.exceptions import ProviderError -from ape.types import ContractLog +from ape.types import ContractLog, CurrencyValueComparable @pytest.fixture @@ -363,3 +363,22 @@ def test_info(solidity_contract_instance): {spec} """.strip() assert actual == expected + + +def test_model_dump(solidity_contract_container, owner): + # NOTE: deploying a new contract with a new number to lessen x-dist conflicts. + contract = owner.deploy(solidity_contract_container, 29620000000003) + + # First, get an event (a normal way). + number = int(10e18) + tx = contract.setNumber(number, sender=owner) + event = tx.events[0] + + # Next, invoke `.model_dump()` to get the serialized version. + log = event.model_dump() + actual = log["event_arguments"] + assert actual["newNum"] == number + + # This next assertion is important because of this Pydantic bug: + # https://github.com/pydantic/pydantic/issues/10152 + assert not isinstance(actual["newNum"], CurrencyValueComparable) diff --git a/tests/functional/test_types.py b/tests/functional/test_types.py index 779e57fa50..4e5ed11bd3 100644 --- a/tests/functional/test_types.py +++ b/tests/functional/test_types.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Optional import pytest from eth_utils import to_hex @@ -6,7 +6,7 @@ from hexbytes import HexBytes from pydantic import BaseModel, Field -from ape.types import AddressType, ContractLog, HexInt, LogFilter +from ape.types import AddressType, ContractLog, CurrencyValueComparable, HexInt, LogFilter from ape.utils import ZERO_ADDRESS TXN_HASH = "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa222222222222222222222222" @@ -131,11 +131,67 @@ class MyModel(BaseModel): class TestHexInt: - class MyModel(BaseModel): - ual: HexInt = 0 - ual_optional: Optional[HexInt] = Field(default=None, validate_default=True) - - act = MyModel.model_validate({"ual": "0x123"}) - expected = 291 # Base-10 form of 0x123. - assert act.ual == expected - assert act.ual_optional is None + def test_model(self): + class MyModel(BaseModel): + ual: HexInt = 0 + ual_optional: Optional[HexInt] = Field(default=None, validate_default=True) + + act = MyModel.model_validate({"ual": "0x123"}) + expected = 291 # Base-10 form of 0x123. + assert act.ual == expected + assert act.ual_optional is None + + +class TestCurrencyValueComparable: + def test_use_for_int_in_pydantic_model(self): + value = 100000000000000000000000000000000000000000000 + + class MyBasicModel(BaseModel): + val: int + + model = MyBasicModel.model_validate({"val": CurrencyValueComparable(value)}) + assert model.val == value + + # Ensure serializes. + dumped = model.model_dump() + assert dumped["val"] == value + + @pytest.mark.parametrize("mode", ("json", "python")) + def test_use_in_model_annotation(self, mode): + value = 100000000000000000000000000000000000000000000 + + class MyAnnotatedModel(BaseModel): + val: CurrencyValueComparable + val_optional: Optional[CurrencyValueComparable] + + model = MyAnnotatedModel.model_validate({"val": value, "val_optional": value}) + assert isinstance(model.val, CurrencyValueComparable) + assert model.val == value + + # Show can use currency-comparable + expected_currency_value = "100000000000000000000000000 ETH" + assert model.val == expected_currency_value + assert model.val_optional == expected_currency_value + + # Ensure serializes. + dumped = model.model_dump(mode=mode) + assert dumped["val"] == value + assert dumped["val_optional"] == value + + def test_validate_from_currency_value(self): + class MyAnnotatedModel(BaseModel): + val: CurrencyValueComparable + val_optional: Optional[CurrencyValueComparable] + val_in_dict: dict[str, Any] + + value = "100000000000000000000000000 ETH" + expected = 100000000000000000000000000000000000000000000 + data = { + "val": value, + "val_optional": value, + "val_in_dict": {"value": CurrencyValueComparable(expected)}, + } + model = MyAnnotatedModel.model_validate(data) + for actual in (model.val, model.val_optional, model.val_in_dict["value"]): + for ex in (value, expected): + assert actual == ex