Skip to content

Commit

Permalink
Fix array unpacking case
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana-s committed Feb 3, 2025
1 parent 8f2483e commit 5a17ef6
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 12 deletions.
3 changes: 2 additions & 1 deletion guppylang/checker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ class SubscriptAccess:
ty: Type
item_expr: ast.expr
getitem_call: ast.expr | None = None
setitem_call: ast.expr | None = None
# Store a temp variable for the RHS of an assignment so it can be assigned a port.
setitem_call: tuple[ast.expr, ast.expr] | None = None

@dataclass(frozen=True)
class Id:
Expand Down
5 changes: 3 additions & 2 deletions guppylang/checker/expr_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,7 @@ def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
return replace(place, parent=check_inout_arg_place(parent, ctx, node))
case SubscriptAccess(parent=parent, item=item, ty=ty):
# Check a call to the `__setitem__` instance function
rhs = with_type(ty, with_loc(node, InoutReturnSentinel(var=place)))
exp_sig = FunctionType(
[
FuncInput(parent.ty, InputFlags.Inout),
Expand All @@ -934,7 +935,7 @@ def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
setitem_args = [
with_type(parent.ty, with_loc(node, PlaceNode(parent))),
with_type(item.ty, with_loc(node, PlaceNode(item))),
with_type(ty, with_loc(node, InoutReturnSentinel(var=place))),
rhs,
]
setitem_call, _ = ExprSynthesizer(ctx).synthesize_instance_func(
setitem_args[0],
Expand All @@ -944,7 +945,7 @@ def check_inout_arg_place(place: Place, ctx: Context, node: PlaceNode) -> Place:
exp_sig,
True,
)
return replace(place, setitem_call=setitem_call)
return replace(place, setitem_call=(setitem_call, rhs))


def synthesize_call(
Expand Down
2 changes: 1 addition & 1 deletion guppylang/checker/linearity_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def _reassign_single_inout_arg(self, place: Place, node: AstNode) -> None:
# Places involving subscripts are given back by visiting the `__setitem__` call
if subscript := contains_subscript(place):
assert subscript.setitem_call is not None
self.visit(subscript.setitem_call)
self.visit(subscript.setitem_call[0])
self._reassign_single_inout_arg(subscript.parent, node)
else:
for leaf in leaf_places(place):
Expand Down
7 changes: 5 additions & 2 deletions guppylang/checker/stmt_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def _check_subscript_assign(
item_expr, item_ty = self._synth_expr(lhs.slice)
item = Variable(next(tmp_vars), item_ty, item_expr)

# Create and store a temp variable to ensure RHS has a wire during compilation.
tmp_rhs = self._check_assign(make_var(next(tmp_vars), rhs), rhs, rhs_ty)

parent = value.place

exp_set_sig = FunctionType(
Expand All @@ -209,7 +212,7 @@ def _check_subscript_assign(
setitem_args = [
with_type(parent.ty, with_loc(lhs, PlaceNode(parent))),
with_type(item.ty, with_loc(lhs, PlaceNode(item))),
rhs,
tmp_rhs,
]
setitem_call, _ = self._synth_instance_fun(
setitem_args[0],
Expand All @@ -221,7 +224,7 @@ def _check_subscript_assign(
)

place = SubscriptAccess(
parent, item, rhs_ty, item_expr, setitem_call=setitem_call
parent, item, rhs_ty, item_expr, setitem_call=(setitem_call, tmp_rhs)
)
return with_loc(lhs, with_type(rhs_ty, PlaceNode(place=place)))

Expand Down
3 changes: 2 additions & 1 deletion guppylang/compiler/expr_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ def _update_inout_ports(
# `arg.place.parent` occurs as an arg of this call, so will also
# be recursively reassigned.
if subscript := contains_subscript(arg.place):
self.visit(subscript.setitem_call)
assert subscript.setitem_call is not None
self.visit(subscript.setitem_call[0])
assert next(inout_ports, None) is None, "Too many inout return ports"

def visit_LocalCall(self, node: LocalCall) -> Wire:
Expand Down
6 changes: 4 additions & 2 deletions guppylang/compiler/stmt_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def _assign(self, lhs: ast.expr, port: Wire) -> None:

@_assign.register
def _assign_place(self, lhs: PlaceNode, port: Wire) -> None:
self.dfg[lhs.place] = port
if (subscript := contains_subscript(lhs.place)) and isinstance(
lhs.place, SubscriptAccess
):
Expand All @@ -81,7 +80,10 @@ def _assign_place(self, lhs: PlaceNode, port: Wire) -> None:
self.dfg[subscript.item] = self.expr_compiler.compile(
subscript.item_expr, self.dfg
)
self.expr_compiler.visit(subscript.setitem_call)
self._assign(subscript.setitem_call[1], port)
self.expr_compiler.visit(subscript.setitem_call[0])
else:
self.dfg[lhs.place] = port

@_assign.register
def _assign_tuple(self, lhs: TupleUnpack, port: Wire) -> None:
Expand Down
4 changes: 1 addition & 3 deletions tests/integration/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,6 @@ def main() -> int:
run_int_fn(compiled, expected=5)


@pytest.mark.skip("TODO: Fix this")
def test_subscript_assign_unpacking_range(validate, run_int_fn):
module = GuppyModule("test")

Expand All @@ -538,10 +537,9 @@ def main() -> int:

compiled = module.compile()
validate(compiled)
run_int_fn(compiled, expected=4)
run_int_fn(compiled, expected=9)


@pytest.mark.skip("TODO: Fix this")
def test_subscript_assign_unpacking_array(validate, run_int_fn):
module = GuppyModule("test")

Expand Down

0 comments on commit 5a17ef6

Please sign in to comment.