From b3062e9b8d0f081c1ce65295e30066aa16c39ee4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Mon, 6 Nov 2023 10:57:33 +0000 Subject: [PATCH] refactor: NodeType constructors, adding new_auto (#635) * Rename NodeType::open_extensions to NodeType::new_open * Rename NodeType::pure to NodeType::new_pure * Add NodeType::new_auto, which uses Pure for module-ops and Open for others * Remove special-case in infer.rs solving some module-ops to empty set * Switch builder/HugrMut methods from new_open to new_auto --- src/builder.rs | 2 +- src/builder/build_traits.rs | 4 ++-- src/builder/cfg.rs | 2 +- src/builder/conditional.rs | 4 ++-- src/builder/dataflow.rs | 2 +- src/builder/module.rs | 2 +- src/builder/tail_loop.rs | 2 +- src/extension/infer.rs | 24 +++++++++-------------- src/hugr.rs | 21 ++++++++++++++------ src/hugr/hugrmut.rs | 6 +++--- src/hugr/serialize.rs | 9 +++------ src/hugr/validate.rs | 36 ++++++++++++++++++++-------------- src/hugr/views/root_checked.rs | 6 +++--- src/hugr/views/sibling.rs | 4 ++-- 14 files changed, 65 insertions(+), 59 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index 52d2b7ca4..b2b8d7371 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -146,7 +146,7 @@ pub(crate) mod test { /// inference. Using DFGBuilder will default to a root node with an open /// extension variable pub(crate) fn closed_dfg_root_hugr(signature: FunctionType) -> Hugr { - let mut hugr = Hugr::new(NodeType::pure(ops::DFG { + let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { signature: signature.clone(), })); hugr.add_op_with_parent( diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 6294df033..950a85903 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -200,7 +200,7 @@ pub trait Dataflow: Container { op: impl Into, input_wires: impl IntoIterator, ) -> Result, BuildError> { - self.add_dataflow_node(NodeType::open_extensions(op), input_wires) + self.add_dataflow_node(NodeType::new_auto(op), input_wires) } /// Add a dataflow [`NodeType`] to the sibling graph, wiring up the `input_wires` to the @@ -628,7 +628,7 @@ fn add_op_with_wires( optype: impl Into, inputs: Vec, ) -> Result<(Node, usize), BuildError> { - add_node_with_wires(data_builder, NodeType::open_extensions(optype), inputs) + add_node_with_wires(data_builder, NodeType::new_auto(optype), inputs) } fn add_node_with_wires( diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index eb168082e..809093652 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -62,7 +62,7 @@ impl CFGBuilder { signature: signature.clone(), }; - let base = Hugr::new(NodeType::open_extensions(cfg_op)); + let base = Hugr::new(NodeType::new_open(cfg_op)); let cfg_node = base.root(); CFGBuilder::create(base, cfg_node, signature.input, signature.output) } diff --git a/src/builder/conditional.rs b/src/builder/conditional.rs index da8808eea..6a46b5e55 100644 --- a/src/builder/conditional.rs +++ b/src/builder/conditional.rs @@ -176,7 +176,7 @@ impl ConditionalBuilder { extension_delta, }; // TODO: Allow input extensions to be specified - let base = Hugr::new(NodeType::open_extensions(op)); + let base = Hugr::new(NodeType::new_open(op)); let conditional_node = base.root(); Ok(ConditionalBuilder { @@ -194,7 +194,7 @@ impl CaseBuilder { let op = ops::Case { signature: signature.clone(), }; - let base = Hugr::new(NodeType::open_extensions(op)); + let base = Hugr::new(NodeType::new_open(op)); let root = base.root(); let dfg_builder = DFGBuilder::create_with_io(base, root, signature, None)?; diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index 9088b3c5c..03480d13b 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -79,7 +79,7 @@ impl DFGBuilder { let dfg_op = ops::DFG { signature: signature.clone(), }; - let base = Hugr::new(NodeType::open_extensions(dfg_op)); + let base = Hugr::new(NodeType::new_open(dfg_op)); let root = base.root(); DFGBuilder::create_with_io(base, root, signature, None) } diff --git a/src/builder/module.rs b/src/builder/module.rs index 83b08a32c..a78c047d7 100644 --- a/src/builder/module.rs +++ b/src/builder/module.rs @@ -90,7 +90,7 @@ impl + AsRef> ModuleBuilder { }; self.hugr_mut().replace_op( f_node, - NodeType::pure(ops::FuncDefn { + NodeType::new_pure(ops::FuncDefn { name, signature: signature.clone(), }), diff --git a/src/builder/tail_loop.rs b/src/builder/tail_loop.rs index a8a07eb4a..5eb8286e1 100644 --- a/src/builder/tail_loop.rs +++ b/src/builder/tail_loop.rs @@ -82,7 +82,7 @@ impl TailLoopBuilder { rest: inputs_outputs.into(), }; // TODO: Allow input extensions to be specified - let base = Hugr::new(NodeType::open_extensions(tail_loop.clone())); + let base = Hugr::new(NodeType::new_open(tail_loop.clone())); let root = base.root(); Self::create_with_io(base, root, &tail_loop) } diff --git a/src/extension/infer.rs b/src/extension/infer.rs index db0b66694..d3fa0fa63 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -316,12 +316,6 @@ impl UnificationContext { m_output, node_type.op_signature().extension_reqs, ); - if matches!( - node_type.tag(), - OpTag::Alias | OpTag::Function | OpTag::FuncDefn - ) { - self.add_solution(m_input, ExtensionSet::new()); - } } // We have a solution for everything! Some(sig) => { @@ -723,7 +717,7 @@ mod test { signature: main_sig, }; - let root_node = NodeType::open_extensions(op); + let root_node = NodeType::new_open(op); let mut hugr = Hugr::new(root_node); let input = ops::Input::new(type_row![NAT, NAT]); @@ -833,21 +827,21 @@ mod test { // This generates a solution that causes validation to fail // because of a missing lift node fn missing_lift_node() -> Result<(), Box> { - let mut hugr = Hugr::new(NodeType::pure(ops::DFG { + let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]) .with_extension_delta(&ExtensionSet::singleton(&A)), })); let input = hugr.add_node_with_parent( hugr.root(), - NodeType::pure(ops::Input { + NodeType::new_pure(ops::Input { types: type_row![NAT], }), )?; let output = hugr.add_node_with_parent( hugr.root(), - NodeType::pure(ops::Output { + NodeType::new_pure(ops::Output { types: type_row![NAT], }), )?; @@ -1049,7 +1043,7 @@ mod test { extension_delta: rs.clone(), }; - let mut hugr = Hugr::new(NodeType::pure(op)); + let mut hugr = Hugr::new(NodeType::new_pure(op)); let conditional_node = hugr.root(); let case_op = ops::Case { @@ -1084,7 +1078,7 @@ mod test { fn extension_adding_sequence() -> Result<(), Box> { let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]); - let mut hugr = Hugr::new(NodeType::open_extensions(ops::DFG { + let mut hugr = Hugr::new(NodeType::new_open(ops::DFG { signature: df_sig .clone() .with_extension_delta(&ExtensionSet::from_iter([A, B])), @@ -1255,7 +1249,7 @@ mod test { let b = ExtensionSet::singleton(&B); let c = ExtensionSet::singleton(&C); - let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc), })); @@ -1353,7 +1347,7 @@ mod test { /// +--------------------+ #[test] fn multi_entry() -> Result<(), Box> { - let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions? })); let cfg = hugr.root(); @@ -1436,7 +1430,7 @@ mod test { ) -> Result> { let hugr_delta = entry_ext.clone().union(&bb1_ext).union(&bb2_ext); - let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + let mut hugr = Hugr::new(NodeType::new_open(ops::CFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]) .with_extension_delta(&hugr_delta), })); diff --git a/src/hugr.rs b/src/hugr.rs index d6dcd5ec6..128eee522 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -82,7 +82,7 @@ impl NodeType { } /// Instantiate an OpType with no input extensions - pub fn pure(op: impl Into) -> Self { + pub fn new_pure(op: impl Into) -> Self { NodeType { op: op.into(), input_extensions: Some(ExtensionSet::new()), @@ -91,13 +91,24 @@ impl NodeType { /// Instantiate an OpType with an unknown set of input extensions /// (to be inferred later) - pub fn open_extensions(op: impl Into) -> Self { + pub fn new_open(op: impl Into) -> Self { NodeType { op: op.into(), input_extensions: None, } } + /// Instantiate an [OpType] with the default set of input extensions + /// for that OpType. + pub fn new_auto(op: impl Into) -> Self { + let op = op.into(); + if OpTag::ModuleOp.is_superset(op.tag()) { + Self::new_pure(op) + } else { + Self::new_open(op) + } + } + /// Use the input extensions to calculate the concrete signature of the node pub fn signature(&self) -> Option { self.input_extensions @@ -119,9 +130,7 @@ impl NodeType { pub fn input_extensions(&self) -> Option<&ExtensionSet> { self.input_extensions.as_ref() } -} -impl NodeType { /// Gets the underlying [OpType] i.e. without any [input_extensions] /// /// [input_extensions]: NodeType::input_extensions @@ -153,7 +162,7 @@ impl OpType { impl Default for Hugr { fn default() -> Self { - Self::new(NodeType::pure(crate::ops::Module)) + Self::new(NodeType::new_pure(crate::ops::Module)) } } @@ -239,7 +248,7 @@ impl Hugr { /// Add a node to the graph, with the default conversion from OpType to NodeType pub(crate) fn add_op(&mut self, op: impl Into) -> Node { - self.add_node(NodeType::open_extensions(op)) + self.add_node(NodeType::new_auto(op)) } /// Add a node to the graph. diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index fca006b5d..3e1ef81dc 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -37,7 +37,7 @@ pub trait HugrMut: HugrMutInternals { parent: Node, op: impl Into, ) -> Result { - self.add_node_with_parent(parent, NodeType::open_extensions(op)) + self.add_node_with_parent(parent, NodeType::new_auto(op)) } /// Add a node to the graph with a parent in the hierarchy. @@ -217,7 +217,7 @@ impl + AsMut> HugrMut for T { } fn add_op_before(&mut self, sibling: Node, op: impl Into) -> Result { - self.add_node_before(sibling, NodeType::open_extensions(op)) + self.add_node_before(sibling, NodeType::new_auto(op)) } fn add_node_before(&mut self, sibling: Node, nodetype: NodeType) -> Result { @@ -620,7 +620,7 @@ mod test { { let f_in = hugr - .add_node_with_parent(f, NodeType::pure(ops::Input::new(type_row![NAT]))) + .add_node_with_parent(f, NodeType::new_pure(ops::Input::new(type_row![NAT]))) .unwrap(); let f_out = hugr .add_op_with_parent(f, ops::Output::new(type_row![NAT, NAT])) diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index 74e380367..edc517b0c 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -222,10 +222,7 @@ impl TryFrom for Hugr { for node_ser in nodes { hugr.add_node_with_parent( node_ser.parent, - match node_ser.input_extensions { - None => NodeType::open_extensions(node_ser.op), - Some(rs) => NodeType::new(node_ser.op, rs), - }, + NodeType::new(node_ser.op, node_ser.input_extensions), )?; } @@ -332,11 +329,11 @@ pub mod test { let mut h = Hierarchy::new(); let mut op_types = UnmanagedDenseMap::new(); - op_types[root] = NodeType::open_extensions(gen_optype(&g, root)); + op_types[root] = NodeType::new_open(gen_optype(&g, root)); for n in [a, b, c] { h.push_child(n, root).unwrap(); - op_types[n] = NodeType::pure(gen_optype(&g, n)); + op_types[n] = NodeType::new_pure(gen_optype(&g, n)); } let hg = Hugr { diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index be8d5062d..9cca30c54 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -853,7 +853,7 @@ mod test { Err(ValidationError::NoParent { node }) => assert_eq!(node, other) ); b.set_parent(other, root).unwrap(); - b.replace_op(other, NodeType::pure(declare_op)).unwrap(); + b.replace_op(other, NodeType::new_pure(declare_op)).unwrap(); b.add_ports(other, Direction::Outgoing, 1); assert_eq!(b.validate(&EMPTY_REG), Ok(())); @@ -872,7 +872,7 @@ mod test { fn leaf_root() { let leaf_op: OpType = LeafOp::Noop { ty: USIZE_T }.into(); - let b = Hugr::new(NodeType::pure(leaf_op)); + let b = Hugr::new(NodeType::new_pure(leaf_op)); assert_eq!(b.validate(&EMPTY_REG), Ok(())); } @@ -883,7 +883,7 @@ mod test { } .into(); - let mut b = Hugr::new(NodeType::pure(dfg_op)); + let mut b = Hugr::new(NodeType::new_pure(dfg_op)); let root = b.root(); add_df_children(&mut b, root, 1); assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); @@ -956,7 +956,7 @@ mod test { .unwrap(); // Replace the output operation of the df subgraph with a copy - b.replace_op(output, NodeType::pure(LeafOp::Noop { ty: NAT })) + b.replace_op(output, NodeType::new_pure(LeafOp::Noop { ty: NAT })) .unwrap(); assert_matches!( b.validate(&EMPTY_REG), @@ -964,8 +964,11 @@ mod test { ); // Revert it back to an output, but with the wrong number of ports - b.replace_op(output, NodeType::pure(ops::Output::new(type_row![BOOL_T]))) - .unwrap(); + b.replace_op( + output, + NodeType::new_pure(ops::Output::new(type_row![BOOL_T])), + ) + .unwrap(); assert_matches!( b.validate(&EMPTY_REG), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) @@ -973,14 +976,14 @@ mod test { ); b.replace_op( output, - NodeType::pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), + NodeType::new_pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), ) .unwrap(); // After fixing the output back, replace the copy with an output op b.replace_op( copy, - NodeType::pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), + NodeType::new_pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), ) .unwrap(); assert_matches!( @@ -1007,7 +1010,7 @@ mod test { b.validate(&EMPTY_REG).unwrap(); b.replace_op( copy, - NodeType::pure(ops::CFG { + NodeType::new_pure(ops::CFG { signature: FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), }), ) @@ -1063,7 +1066,7 @@ mod test { // Change the types in the BasicBlock node to work on qubits instead of bits b.replace_op( block, - NodeType::pure(ops::BasicBlock::DFB { + NodeType::new_pure(ops::BasicBlock::DFB { inputs: type_row![Q], tuple_sum_rows: vec![type_row![]], other_outputs: type_row![Q], @@ -1074,11 +1077,14 @@ mod test { let mut block_children = b.hierarchy.children(block.pg_index()); let block_input = block_children.next().unwrap().into(); let block_output = block_children.next_back().unwrap().into(); - b.replace_op(block_input, NodeType::pure(ops::Input::new(type_row![Q]))) - .unwrap(); + b.replace_op( + block_input, + NodeType::new_pure(ops::Input::new(type_row![Q])), + ) + .unwrap(); b.replace_op( block_output, - NodeType::pure(ops::Output::new(type_row![Type::new_unit_sum(1), Q])), + NodeType::new_pure(ops::Output::new(type_row![Type::new_unit_sum(1), Q])), ) .unwrap(); assert_matches!( @@ -1310,12 +1316,12 @@ mod test { let main_signature = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); - let mut hugr = Hugr::new(NodeType::pure(ops::DFG { + let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { signature: main_signature, })); let input = hugr.add_node_with_parent( hugr.root(), - NodeType::pure(ops::Input { + NodeType::new_pure(ops::Input { types: type_row![NAT], }), )?; diff --git a/src/hugr/views/root_checked.rs b/src/hugr/views/root_checked.rs index 26815e8ed..6b3f7aba3 100644 --- a/src/hugr/views/root_checked.rs +++ b/src/hugr/views/root_checked.rs @@ -79,7 +79,7 @@ mod test { #[test] fn root_checked() { - let root_type = NodeType::pure(ops::DFG { + let root_type = NodeType::new_pure(ops::DFG { signature: FunctionType::new(vec![], vec![]), }); let mut h = Hugr::new(root_type.clone()); @@ -94,7 +94,7 @@ mod test { let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); // That is a HugrMutInternal, so we can try: let root = dfg_v.root(); - let bb = NodeType::pure(BasicBlock::DFB { + let bb = NodeType::new_pure(BasicBlock::DFB { inputs: type_row![], other_outputs: type_row![], tuple_sum_rows: vec![type_row![]], @@ -129,7 +129,7 @@ mod test { let mut bb_v = RootChecked::<_, BasicBlockID>::try_new(dfp_v).unwrap(); // And it's a HugrMut: - let nodetype = NodeType::pure(LeafOp::MakeTuple { tys: type_row![] }); + let nodetype = NodeType::new_pure(LeafOp::MakeTuple { tys: type_row![] }); bb_v.add_node_with_parent(bb_v.root(), nodetype).unwrap(); } } diff --git a/src/hugr/views/sibling.rs b/src/hugr/views/sibling.rs index 69a2da4f5..77f74f39b 100644 --- a/src/hugr/views/sibling.rs +++ b/src/hugr/views/sibling.rs @@ -454,7 +454,7 @@ mod test { ); let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); - let bad_nodetype = NodeType::open_extensions(crate::ops::CFG { signature }); + let bad_nodetype = NodeType::new_open(crate::ops::CFG { signature }); assert_eq!( sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()), Err(HugrError::InvalidTag { @@ -471,7 +471,7 @@ mod test { #[rstest] fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) { let root = simple_dfg_hugr.root(); - let case_nodetype = NodeType::open_extensions(crate::ops::Case { + let case_nodetype = NodeType::new_open(crate::ops::Case { signature: simple_dfg_hugr.root_type().op_signature(), }); let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap();