Skip to content

Commit

Permalink
feat: allow using the throwaway ('_') var
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-makerx committed Mar 13, 2024
1 parent e21dc55 commit 5feb20a
Show file tree
Hide file tree
Showing 23 changed files with 915 additions and 17 deletions.
1 change: 1 addition & 0 deletions examples/sizes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ stubs/BigUInt 172 112 112
stubs/Bytes 1769 258 258
stubs/Uint64 371 8 8
template_variables/TemplateVariables 168 155 155
throwaway/Throwaway 45 31 31
too_many_permutations 108 107 107
transaction/Transaction 893 849 849
tuple_support/TupleSupport 442 294 294
Expand Down
4 changes: 2 additions & 2 deletions src/puya/awst/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ class AssignmentStatement(Statement):
value: Expression = attrs.field(validator=[lvalue_expr_validator])

def __attrs_post_init__(self) -> None:
if self.value.wtype != self.target.wtype:
if not wtypes.is_valid_assignment(self.target.wtype, self.value.wtype):
raise CodeError(
f"Assignment target type {self.target.wtype}"
f" differs from expression value type {self.value.wtype}",
Expand Down Expand Up @@ -1063,7 +1063,7 @@ def __init__(self, value: Expression, target: Lvalue, source_location: SourceLoc
"Tuple unpacking in assignment expressions is not supported",
target.source_location,
)
if value.wtype != target.wtype:
if not wtypes.is_valid_assignment(target.wtype, value.wtype):
raise CodeError(
f"Assignment target type {target.wtype}"
f" differs from expression value type {value.wtype}",
Expand Down
20 changes: 20 additions & 0 deletions src/puya/awst/wtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def is_valid_utf8_literal(value: object) -> typing.TypeGuard[str]:
lvalue=False,
)

throwaway_type: typing.Final = WType(
name="throwaway",
stub_name="typing.Any",
)

bool_wtype: typing.Final = WType(
name="bool",
stub_name="bool",
Expand Down Expand Up @@ -526,3 +531,18 @@ def arc4_to_avm_equivalent_wtype(arc4_wtype: WType) -> WType:
return bool_wtype

raise InternalError(f"Invalid arc4_wtype: {arc4_wtype}")


def is_valid_assignment(target: WType, source: WType) -> bool:
if target == throwaway_type:
return True
if (
isinstance(target, WTuple)
and isinstance(source, WTuple)
and len(target.types) == len(source.types)
):
return all(
is_valid_assignment(target_item, source_item)
for target_item, source_item in zip(target.types, source.types, strict=True)
)
return target == source
68 changes: 68 additions & 0 deletions src/puya/awst_build/eb/throwaway.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import typing
from collections.abc import Sequence

import mypy.nodes

from puya.awst import wtypes
from puya.awst.nodes import Expression, Literal, Lvalue, VarExpression
from puya.awst_build.eb.base import ExpressionBuilder
from puya.errors import CodeError, InternalError
from puya.parse import SourceLocation


class ThrowawayExpressionBuilder(ExpressionBuilder):
wtype = wtypes.throwaway_type

def __init__(self, expr: Expression):
super().__init__(expr.source_location)
if not isinstance(expr, VarExpression):
raise InternalError("Expected a VarExpression", expr.source_location)
self._expr = expr

def lvalue(self) -> Lvalue:
return self._expr

def rvalue(self) -> typing.Never:
self._raise_error(self.source_location)

def delete(self, location: SourceLocation) -> typing.Never:
self._raise_error(location)

def index(self, index: ExpressionBuilder | Literal, location: SourceLocation) -> typing.Never:
self._raise_error(location)

def call(
self,
args: Sequence[ExpressionBuilder | Literal],
arg_kinds: list[mypy.nodes.ArgKind],
arg_names: list[str | None],
location: SourceLocation,
original_expr: mypy.nodes.CallExpr,
) -> typing.Never:
self._raise_error(location)

def member_access(self, name: str, location: SourceLocation) -> typing.Never:
self._raise_error(location)

def iterate(self) -> typing.Never:
self._raise_error(self.source_location)

def bool_eval(self, location: SourceLocation, *, negate: bool = False) -> typing.Never:
self._raise_error(location)

def unary_plus(self, location: SourceLocation) -> typing.Never:
self._raise_error(location)

def unary_minus(self, location: SourceLocation) -> typing.Never:
self._raise_error(location)

def bitwise_invert(self, location: SourceLocation) -> typing.Never:
self._raise_error(location)

def contains(
self, item: ExpressionBuilder | Literal, location: SourceLocation
) -> typing.Never:
self._raise_error(location)

def _raise_error(self, location: SourceLocation) -> typing.Never:
raise CodeError("'_' variables can only be assigned to", location)
7 changes: 6 additions & 1 deletion src/puya/awst_build/eb/type_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
named_int_constants,
struct,
template_variables,
throwaway,
transaction,
tuple as tuple_,
uint64,
Expand Down Expand Up @@ -133,6 +134,7 @@
wtypes.WGroupTransaction: transaction.GroupTransactionExpressionBuilder,
wtypes.WInnerTransaction: transaction.InnerTransactionExpressionBuilder,
wtypes.WInnerTransactionFields: transaction.InnerTxnParamsExpressionBuilder,
wtypes.throwaway_type: throwaway.ThrowawayExpressionBuilder,
}


Expand All @@ -149,5 +151,8 @@ def var_expression(expr: Expression) -> ExpressionBuilder:
try:
builder = WTYPE_TO_BUILDER[expr.wtype]
except KeyError:
builder = WTYPE_TO_BUILDER[type(expr.wtype)]
try:
builder = WTYPE_TO_BUILDER[type(expr.wtype)]
except KeyError as ex:
raise InternalError(f"Unhandled wtype: {expr.wtype}", expr.source_location) from ex
return builder(expr)
43 changes: 30 additions & 13 deletions src/puya/awst_build/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
VarExpression,
WhileLoop,
)
from puya.awst.wtypes import WType
from puya.awst_build import constants, intrinsic_data
from puya.awst_build.base_mypy_visitor import BaseMyPyVisitor
from puya.awst_build.context import ASTConversionModuleContext
Expand Down Expand Up @@ -120,6 +119,7 @@ def __init__(
func_loc = self._location(func_def)
self.contract_method_info = contract_method_info
self._is_bool_context = False
self._is_lvalue_context = False
self.func_def = func_def
self._precondition(
func_def.abstract_status == mypy.nodes.NOT_ABSTRACT,
Expand Down Expand Up @@ -161,7 +161,11 @@ def __init__(
func_loc,
)
# TODO: this should be more than just type?
self._symtable = dict[str, WType]()
# mypy has special behaviour to treat '_' as Any, so predefine the '_' symbol as
# a type that can only be assigned to
self._symtable = {
"_": wtypes.throwaway_type,
}
args = list[SubroutineArgument]()
for arg, arg_type in zip(mypy_args, mypy_arg_types, strict=True):
if arg.kind.is_star():
Expand Down Expand Up @@ -266,6 +270,18 @@ def _set_bool_context(self, *, is_bool_context: bool) -> Iterator[None]:
_enter_bool_context = partialmethod(_set_bool_context, is_bool_context=True)
_leave_bool_context = partialmethod(_set_bool_context, is_bool_context=False)

@contextlib.contextmanager
def _set_lvalue_context(self, *, is_lvalue_context: bool) -> Iterator[None]:
was_lvalue_context = self._is_lvalue_context
self._is_lvalue_context = is_lvalue_context
try:
yield
finally:
self._is_lvalue_context = was_lvalue_context

_enter_lvalue_context = partialmethod(_set_lvalue_context, is_lvalue_context=True)
_leave_lvalue_context = partialmethod(_set_lvalue_context, is_lvalue_context=False)

def visit_expression_stmt(self, stmt: mypy.nodes.ExpressionStmt) -> ExpressionStatement:
self._precondition(
stmt.line == stmt.expr.line
Expand Down Expand Up @@ -391,9 +407,10 @@ def _handle_state_proxy_assignment(
]

def resolve_lvalue(self, lvalue: mypy.nodes.Expression) -> Lvalue:
builder_or_literal = lvalue.accept(self)
builder = require_expression_builder(builder_or_literal)
return builder.lvalue()
with self._enter_lvalue_context():
builder_or_literal = lvalue.accept(self)
builder = require_expression_builder(builder_or_literal)
return builder.lvalue()

def empty_statement(self, _tmt: mypy.nodes.Statement) -> None:
return None
Expand Down Expand Up @@ -700,9 +717,6 @@ def _visit_ref_expr_maybe_aliased(
" as a singular lvalue in an assignment statement",
expr_loc,
)
if var_name == "_":
# TODO: ignore "_"
raise CodeError("_ is not currently supported as a variable name", expr_loc)
local_type = lazy_setdefault(
self._symtable,
key=var_name,
Expand Down Expand Up @@ -1146,11 +1160,14 @@ def visit_bytes_expr(self, expr: mypy.nodes.BytesExpr) -> Literal:

def visit_tuple_expr(self, mypy_expr: mypy.nodes.TupleExpr) -> ExpressionBuilder:
items = [
require_expression_builder(
mypy_item.accept(self),
msg="Python literals (other than True/False) are not valid as tuple elements",
).rvalue()
for mypy_item in mypy_expr.items
eb.lvalue() if self._is_lvalue_context else eb.rvalue()
for eb in (
require_expression_builder(
mypy_item.accept(self),
msg="Python literals (other than True/False) are not valid as tuple elements",
)
for mypy_item in mypy_expr.items
)
]
wtype = wtypes.WTuple.from_types(i.wtype for i in items)
tuple_expr = TupleExpression(
Expand Down
20 changes: 20 additions & 0 deletions test_cases/throwaway/contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from puyapy import Account, Bytes, Contract, Global, Txn, UInt64, subroutine


class Throwaway(Contract):
def approval_program(self) -> bool:
tup = get_tuple()
args, sender, _ = tup
_, _, approval = tup
assert sender == Global.creator_address
assert args == 0
assert approval
return True

def clear_state_program(self) -> bool:
return True


@subroutine
def get_tuple() -> tuple[UInt64, Account, Bytes]:
return Txn.num_app_args, Txn.sender, Txn.approval_program
48 changes: 48 additions & 0 deletions test_cases/throwaway/out/Throwaway.approval.mir
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Op // Op Description Stack (out) X stack Source code Source line

#pragma version 10

// test_cases.throwaway.contract.Throwaway.approval_program() -> uint64:
main_block@0:
callsub get_tuple // {get_tuple}.0,{get_tuple}.1,{get_tuple}.2 get_tuple() throwaway/contract.py:6
cover 2 // store tup.2#0 to l-stack (no copy) tup.2#0,{get_tuple}.0,{get_tuple}.1 tup = get_tuple() throwaway/contract.py:6
// virtual: store tup.1#0 to l-stack (no copy) tup.2#0,tup.1#0,{get_tuple}.0 tup = get_tuple() throwaway/contract.py:6
// virtual: store tup.0#0 to l-stack (no copy) tup.2#0,tup.0#0,tup.1#0 tup = get_tuple() throwaway/contract.py:6
global CreatorAddress // tup.2#0,tup.0#0,tup.1#0,{global} Global.creator_address throwaway/contract.py:9
// virtual: store tmp%0#0 to l-stack (no copy) tup.2#0,tup.0#0,tup.1#0,tmp%0#0 Global.creator_address throwaway/contract.py:9
// virtual: load tup.1#0 from l-stack (no copy) tup.2#0,tup.0#0,tmp%0#0,tup.1#0 sender == Global.creator_address throwaway/contract.py:9
// virtual: load tmp%0#0 from l-stack (no copy) tup.2#0,tup.0#0,tup.1#0,tmp%0#0 sender == Global.creator_address throwaway/contract.py:9
== // tup.2#0,tup.0#0,{==} sender == Global.creator_address throwaway/contract.py:9
// virtual: store tmp%1#0 to l-stack (no copy) tup.2#0,tup.0#0,tmp%1#0 sender == Global.creator_address throwaway/contract.py:9
// virtual: load tmp%1#0 from l-stack (no copy) tup.2#0,tup.0#0,tmp%1#0 assert sender == Global.creator_address throwaway/contract.py:9
assert // tup.2#0,tup.0#0 assert sender == Global.creator_address throwaway/contract.py:9
// virtual: load tup.0#0 from l-stack (no copy) tup.2#0,tup.0#0 args == 0 throwaway/contract.py:10
! // tup.2#0,{!} args == 0 throwaway/contract.py:10
// virtual: store tmp%2#0 to l-stack (no copy) tup.2#0,tmp%2#0 args == 0 throwaway/contract.py:10
// virtual: load tmp%2#0 from l-stack (no copy) tup.2#0,tmp%2#0 assert args == 0 throwaway/contract.py:10
assert // tup.2#0 assert args == 0 throwaway/contract.py:10
// virtual: load tup.2#0 from l-stack (no copy) tup.2#0 approval throwaway/contract.py:11
len // {len} approval throwaway/contract.py:11
// virtual: store tmp%3#0 to l-stack (no copy) tmp%3#0 approval throwaway/contract.py:11
// virtual: load tmp%3#0 from l-stack (no copy) tmp%3#0 assert approval throwaway/contract.py:11
assert // assert approval throwaway/contract.py:11
int 1 // 1 True throwaway/contract.py:12
return // return True throwaway/contract.py:12


// test_cases.throwaway.contract.get_tuple() -> uint64, bytes, bytes:
get_tuple:
proto 0 3 // @subroutine\ndef get_tuple() -> tuple[UInt64, Account, Bytes]: throwaway/contract.py:18-19

get_tuple_block@0:
txn NumAppArgs // {txn} Txn.num_app_args throwaway/contract.py:20
// virtual: store tmp%0#0 to l-stack (no copy) tmp%0#0 Txn.num_app_args throwaway/contract.py:20
txn Sender // tmp%0#0,{txn} Txn.sender throwaway/contract.py:20
// virtual: store tmp%1#0 to l-stack (no copy) tmp%0#0,tmp%1#0 Txn.sender throwaway/contract.py:20
txn ApprovalProgram // tmp%0#0,tmp%1#0,{txn} Txn.approval_program throwaway/contract.py:20
// virtual: store tmp%2#0 to l-stack (no copy) tmp%0#0,tmp%1#0,tmp%2#0 Txn.approval_program throwaway/contract.py:20
uncover 2 // load tmp%0#0 from l-stack (no copy) tmp%1#0,tmp%2#0,tmp%0#0 return Txn.num_app_args, Txn.sender, Txn.approval_program throwaway/contract.py:20
uncover 2 // load tmp%1#0 from l-stack (no copy) tmp%2#0,tmp%0#0,tmp%1#0 return Txn.num_app_args, Txn.sender, Txn.approval_program throwaway/contract.py:20
uncover 2 // load tmp%2#0 from l-stack (no copy) tmp%0#0,tmp%1#0,tmp%2#0 return Txn.num_app_args, Txn.sender, Txn.approval_program throwaway/contract.py:20
retsub // tmp%0#0,tmp%1#0,tmp%2#0 return Txn.num_app_args, Txn.sender, Txn.approval_program throwaway/contract.py:20

Loading

0 comments on commit 5feb20a

Please sign in to comment.