Skip to content

Commit

Permalink
refactor(py): NodeIdx and PortOffset aliases for int (#1339)
Browse files Browse the repository at this point in the history
Small refactor adding a type alias for node indices and port offsets.
  • Loading branch information
aborgna-q authored Jul 24, 2024
1 parent 56c3f5f commit 67374ca
Show file tree
Hide file tree
Showing 8 changed files with 44 additions and 29 deletions.
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .hugr import Hugr, ParentBuilder

if TYPE_CHECKING:
from .node_port import Node, ToNode, Wire
from .node_port import Node, PortOffset, ToNode, Wire
from .tys import Type, TypeRow


Expand All @@ -29,7 +29,7 @@ def set_single_succ_outputs(self, *outputs: Wire) -> None:
u = self.load(val.Unit)
self.set_outputs(u, *outputs)

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> Type:
def _wire_up_port(self, node: Node, offset: PortOffset, p: Wire) -> Type:
src = p.out_port()
cfg_node = self.hugr[self.parent_node].parent
assert cfg_node is not None
Expand Down
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .cfg import Cfg
from .cond_loop import Conditional, If, TailLoop
from .node_port import Node, OutPort, ToNode, Wire
from .node_port import Node, OutPort, PortOffset, ToNode, Wire


DP = TypeVar("DP", bound=ops.DfParentOp)
Expand Down Expand Up @@ -554,7 +554,7 @@ def _get_dataflow_type(self, wire: Wire) -> tys.Type:
raise ValueError(msg)
return ty

def _wire_up_port(self, node: Node, offset: int, p: Wire) -> tys.Type:
def _wire_up_port(self, node: Node, offset: PortOffset, p: Wire) -> tys.Type:
src = p.out_port()
node_ancestor = _ancestral_sibling(self.hugr, src.node, node)
if node_ancestor is None:
Expand Down
12 changes: 7 additions & 5 deletions hugr-py/src/hugr/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from dataclasses import dataclass

from .node_port import NodeIdx


@dataclass
class NoSiblingAncestor(Exception):
"""No sibling ancestor of target for valid inter-graph edge."""

src: int
tgt: int
src: NodeIdx
tgt: NodeIdx

@property
def msg(self):
Expand All @@ -22,8 +24,8 @@ def msg(self):
class NotInSameCfg(Exception):
"""Source and target nodes are not in the same CFG."""

src: int
tgt: int
src: NodeIdx
tgt: NodeIdx

@property
def msg(self):
Expand All @@ -37,7 +39,7 @@ def msg(self):
class MismatchedExit(Exception):
"""Edge to exit block signature mismatch."""

src: int
src: NodeIdx

@property
def msg(self):
Expand Down
15 changes: 12 additions & 3 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
overload,
)

from hugr.node_port import Direction, InPort, Node, OutPort, ToNode, _SubPort
from hugr.node_port import (
Direction,
InPort,
Node,
NodeIdx,
OutPort,
PortOffset,
ToNode,
_SubPort,
)
from hugr.ops import Call, Const, DataflowOp, Module, Op
from hugr.serialization.ops import OpType as SerialOp
from hugr.serialization.serial_hugr import SerialHugr
Expand Down Expand Up @@ -545,7 +554,7 @@ def to_serial(self) -> SerialHugr:

def _serialize_link(
link: tuple[_SO, _SI],
) -> tuple[tuple[int, int], tuple[int, int]]:
) -> tuple[tuple[NodeIdx, PortOffset], tuple[NodeIdx, PortOffset]]:
src, dst = link
s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port)
return (src.port.node.idx, s), (dst.port.node.idx, d)
Expand All @@ -557,7 +566,7 @@ def _serialize_link(
edges=[_serialize_link(link) for link in self._links.items()],
)

def _constrain_offset(self, p: P) -> int:
def _constrain_offset(self, p: P) -> PortOffset:
# negative offsets are used to refer to the last port
if p.offset < 0:
match p.direction:
Expand Down
24 changes: 14 additions & 10 deletions hugr-py/src/hugr/node_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@ class Direction(Enum):
OUTGOING = 1


NodeIdx = int
PortOffset = int


@dataclass(frozen=True, eq=True, order=True)
class _Port:
node: Node
offset: int
offset: PortOffset
direction: ClassVar[Direction]


Expand Down Expand Up @@ -72,21 +76,21 @@ def to_node(self) -> Node:
... # pragma: no cover

@overload
def __getitem__(self, index: int) -> OutPort: ...
def __getitem__(self, index: PortOffset) -> OutPort: ...
@overload
def __getitem__(self, index: slice) -> Iterator[OutPort]: ...
@overload
def __getitem__(self, index: tuple[int, ...]) -> Iterator[OutPort]: ...
def __getitem__(self, index: tuple[PortOffset, ...]) -> Iterator[OutPort]: ...

def __getitem__(
self, index: int | slice | tuple[int, ...]
self, index: PortOffset | slice | tuple[PortOffset, ...]
) -> OutPort | Iterator[OutPort]:
return self.to_node()._index(index)

def out_port(self) -> OutPort:
return OutPort(self.to_node(), 0)

def inp(self, offset: int) -> InPort:
def inp(self, offset: PortOffset) -> InPort:
"""Generate an input port for this node.
Args:
Expand All @@ -101,7 +105,7 @@ def inp(self, offset: int) -> InPort:
"""
return InPort(self.to_node(), offset)

def out(self, offset: int) -> OutPort:
def out(self, offset: PortOffset) -> OutPort:
"""Generate an output port for this node.
Args:
Expand All @@ -116,7 +120,7 @@ def out(self, offset: int) -> OutPort:
"""
return OutPort(self.to_node(), offset)

def port(self, offset: int, direction: Direction) -> InPort | OutPort:
def port(self, offset: PortOffset, direction: Direction) -> InPort | OutPort:
"""Generate a port in `direction` for this node with `offset`.
Examples:
Expand All @@ -137,14 +141,14 @@ class Node(ToNode):
with globally unique index.
"""

idx: int
idx: NodeIdx
_num_out_ports: int | None = field(default=None, compare=False)

def _index(
self, index: int | slice | tuple[int, ...]
self, index: PortOffset | slice | tuple[PortOffset, ...]
) -> OutPort | Iterator[OutPort]:
match index:
case int(index):
case PortOffset(index):
if self._num_out_ports is not None and index >= self._num_out_ports:
msg = "Index out of range"
raise IndexError(msg)
Expand Down
4 changes: 2 additions & 2 deletions hugr-py/src/hugr/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import hugr.serialization.ops as sops
from hugr import tys, val
from hugr.node_port import Direction, InPort, Node, OutPort, Wire
from hugr.node_port import Direction, InPort, Node, OutPort, PortOffset, Wire
from hugr.utils import ser_it

if TYPE_CHECKING:
Expand Down Expand Up @@ -968,7 +968,7 @@ def to_serial(self, parent: Node) -> sops.Call:
def num_out(self) -> int:
return len(self.signature.body.output)

def _function_port_offset(self) -> int:
def _function_port_offset(self) -> PortOffset:
return len(self.signature.body.input)

def port_kind(self, port: InPort | OutPort) -> tys.Kind:
Expand Down
5 changes: 2 additions & 3 deletions hugr-py/src/hugr/serialization/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from pydantic import ConfigDict, Field, RootModel

from hugr.node_port import NodeIdx # noqa: TCH001 # pydantic needs this alias in scope
from hugr.utils import deser_it

from . import tys as stys
Expand All @@ -28,14 +29,12 @@
model_rebuild as tys_model_rebuild,
)

NodeID = int


class BaseOp(ABC, ConfiguredBaseModel):
"""Base class for ops that store their node's input/output types."""

# Parent node index of node the op belongs to, used only at serialization time
parent: NodeID
parent: NodeIdx

def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None:
"""Hook to insert type information from the input and output ports into the
Expand Down
5 changes: 3 additions & 2 deletions hugr-py/src/hugr/serialization/serial_hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

import hugr
from hugr import get_serialisation_version
from hugr.node_port import NodeIdx, PortOffset

from .ops import NodeID, OpType
from .ops import OpType
from .ops import classes as ops_classes
from .tys import ConfiguredBaseModel, model_rebuild

Port = tuple[NodeID, int | None] # (node, offset)
Port = tuple[NodeIdx, PortOffset | None]
Edge = tuple[Port, Port]

VersionField = Field(
Expand Down

0 comments on commit 67374ca

Please sign in to comment.