diff --git a/src/puya/awst/wtypes.py b/src/puya/awst/wtypes.py index c52d23624..c5902db18 100644 --- a/src/puya/awst/wtypes.py +++ b/src/puya/awst/wtypes.py @@ -335,20 +335,30 @@ def _m_validator(self, _attribute: object, m: int) -> None: raise CodeError("Precision must be between 1 and 160 inclusive", self.source_location) -def _required_arc4_wtypes(wtypes: Iterable[WType]) -> tuple[ARC4Type, ...]: - result = [] - for wtype in wtypes: - if not isinstance(wtype, ARC4Type): - raise CodeError(f"expected ARC4 type: {wtype}") - result.append(wtype) - return tuple(result) +def _tuple_requires_arc4_types(wtypes: Iterable[WType], node: WType) -> tuple[ARC4Type, ...]: + return tuple( + _narrow_arc4_wtype(wtype, "tuple", wtype.source_location or node.source_location) + for wtype in wtypes + ) + + +def _array_requires_arc4_type(wtype: WType, node: WType) -> ARC4Type: + return _narrow_arc4_wtype(wtype, "array", wtype.source_location or node.source_location) + + +def _narrow_arc4_wtype(wtype: WType, context: str, loc: SourceLocation | None) -> ARC4Type: + if isinstance(wtype, ARC4Type): + return wtype + raise CodeError(f"{wtype} is not supported in an ARC4 {context}", loc) @typing.final @attrs.frozen(kw_only=True) class ARC4Tuple(ARC4Type): source_location: SourceLocation | None = attrs.field(default=None, eq=False) - types: tuple[ARC4Type, ...] = attrs.field(converter=_required_arc4_wtypes) + types: tuple[ARC4Type, ...] = attrs.field( + converter=attrs.Converter(_tuple_requires_arc4_types, takes_self=True) # type: ignore[misc, call-overload] + ) name: str = attrs.field(init=False) arc4_name: str = attrs.field(init=False, eq=False) immutable: bool = attrs.field(init=False) @@ -387,15 +397,11 @@ def _is_arc4_encodeable_tuple( ) -def _expect_arc4_type(wtype: WType) -> ARC4Type: - if not isinstance(wtype, ARC4Type): - raise CodeError(f"expected ARC4 type: {wtype}") - return wtype - - @attrs.frozen(kw_only=True) class ARC4Array(ARC4Type): - element_type: ARC4Type = attrs.field(converter=_expect_arc4_type) + element_type: ARC4Type = attrs.field( + converter=attrs.Converter(_array_requires_arc4_type, takes_self=True) # type: ignore[misc, call-overload] + ) native_type: WType | None = None immutable: bool = False diff --git a/tests/test_expected_output/arc4.test b/tests/test_expected_output/arc4.test index 33a6f4d94..5de1f23a5 100644 --- a/tests/test_expected_output/arc4.test +++ b/tests/test_expected_output/arc4.test @@ -65,8 +65,8 @@ def wrong_arg_type3() -> None: @subroutine def wrong_arg_type4() -> None: - arc4.emit("Event(string,uint8)", Event(arc4.String("a"), arc4.UInt8(1))) ## E: expected type algopy.arc4.String, got type test_emit_errors.Event \ - ## E: expected 2 ABI arguments, got 1 + arc4.emit("Event(string,uint8)", Event(arc4.String("a"), arc4.UInt8(1))) ## E: expected 2 ABI arguments, got 1 \ + ## E: expected type algopy.arc4.String, got type test_emit_errors.Event @subroutine def wrong_arg_type5() -> None: @@ -655,3 +655,14 @@ def c() -> None: @subroutine def d() -> None: assert not arc4.DynamicArray[int]() ## E: Python literals of type int cannot be used as runtime values + +## case: abimethod_reference_type_in_tuple +from algopy import * + +class ARC4Reference(arc4.ARC4Contract): + @arc4.abimethod() + def test(self, + arg: tuple[Account, Asset], ## E: account is not supported in an ARC4 tuple + arg2: tuple[Account, Application], + ) -> None: + pass