Skip to content

Commit

Permalink
feat!: HugrMut::remove_node and SimpleReplacement return removed …
Browse files Browse the repository at this point in the history
…weights (#1516)

Closes #476 

BREAKING CHANGE: `remove_node` now returns an OpType, and the
`ApplyResult` of `SimpleReplacement` holds replaced nodes and weights.
  • Loading branch information
ss2165 authored Oct 2, 2024
1 parent e45ab5d commit ea8e818
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
8 changes: 4 additions & 4 deletions hugr-core/src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ pub trait HugrMut: HugrMutInternals {
self.hugr_mut().add_node_after(sibling, op)
}

/// Remove a node from the graph.
/// Remove a node from the graph and return the node weight.
///
/// # Panics
///
/// If the node is not in the graph, or if the node is the root node.
#[inline]
fn remove_node(&mut self, node: Node) {
fn remove_node(&mut self, node: Node) -> OpType {
panic_invalid_non_root(self, node);
self.hugr_mut().remove_node(node)
}
Expand Down Expand Up @@ -264,11 +264,11 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
node
}

fn remove_node(&mut self, node: Node) {
fn remove_node(&mut self, node: Node) -> OpType {
panic_invalid_non_root(self, node);
self.as_mut().hierarchy.remove(node.pg_index());
self.as_mut().graph.remove_node(node.pg_index());
self.as_mut().op_types.remove(node.pg_index());
self.as_mut().op_types.take(node.pg_index())
}

fn connect(
Expand Down
4 changes: 3 additions & 1 deletion hugr-core/src/hugr/rewrite/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ impl Rewrite for Replacement {
}

// 7. Remove remaining nodes
to_remove.into_iter().for_each(|n| h.remove_node(n));
to_remove.into_iter().for_each(|n| {
h.remove_node(n);
});
Ok(node_map)
}

Expand Down
16 changes: 9 additions & 7 deletions hugr-core/src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ impl SimpleReplacement {

impl Rewrite for SimpleReplacement {
type Error = SimpleReplacementError;
type ApplyResult = ();
type ApplyResult = Vec<(Node, OpType)>;
const UNCHANGED_ON_FAILURE: bool = true;

fn verify(&self, _h: &impl HugrView) -> Result<(), SimpleReplacementError> {
unimplemented!()
}

fn apply(mut self, h: &mut impl HugrMut) -> Result<(), SimpleReplacementError> {
fn apply(mut self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
let parent = self.subgraph.get_parent(h);
// 1. Check the parent node exists and is a DataflowParent.
if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
Expand Down Expand Up @@ -184,10 +184,12 @@ impl Rewrite for SimpleReplacement {
});

// 3.5. Remove all nodes in self.removal and edges between them.
for &node in self.subgraph.nodes() {
h.remove_node(node);
}
Ok(())
Ok(self
.subgraph
.nodes()
.iter()
.map(|&node| (node, h.remove_node(node)))
.collect())
}

#[inline]
Expand Down Expand Up @@ -831,7 +833,7 @@ pub(in crate::hugr::rewrite) mod test {
}

fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) {
h.apply_rewrite(rw).unwrap()
h.apply_rewrite(rw).unwrap();
}

fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) {
Expand Down

0 comments on commit ea8e818

Please sign in to comment.