diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 224fc47f1..289559539 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -2,7 +2,7 @@ use crate::{ extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc}, hugr::{IdentList, NodeMetadataMap}, - ops::{DataflowBlock, OpName, OpTrait, OpType}, + ops::{constant::CustomSerialized, DataflowBlock, OpName, OpTrait, OpType, Value}, types::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, @@ -20,6 +20,7 @@ pub(crate) const OP_FUNC_CALL_INDIRECT: &str = "func.call-indirect"; const TERM_PARAM_TUPLE: &str = "param.tuple"; const TERM_JSON: &str = "prelude.json"; const META_DESCRIPTION: &str = "docs.description"; +const TERM_JSON_CONST: &str = "prelude.const-json"; /// Export a [`Hugr`] graph to its representation in the model. pub fn export_hugr<'a>(hugr: &'a Hugr, bump: &'a Bump) -> model::Module<'a> { @@ -71,7 +72,12 @@ struct Context<'a> { implicit_imports: FxHashMap<&'a str, model::NodeId>, /// Map from node ids in the [`Hugr`] to the corresponding node ids in the model. - node_indices: FxHashMap, + node_to_id: FxHashMap, + + /// Mapping from node ids in the [`Hugr`] to the corresponding model nodes. + id_to_node: FxHashMap, + // TODO: Once this module matures, we should consider adding an auxiliary structure + // that ensures that the `node_to_id` and `id_to_node` maps stay in sync. } impl<'a> Context<'a> { @@ -89,7 +95,8 @@ impl<'a> Context<'a> { local_constraints: Vec::new(), symbols: model::scope::SymbolTable::default(), implicit_imports: FxHashMap::default(), - node_indices: FxHashMap::default(), + node_to_id: FxHashMap::default(), + id_to_node: FxHashMap::default(), links: model::scope::LinkTable::default(), } } @@ -104,11 +111,13 @@ impl<'a> Context<'a> { let mut children = Vec::with_capacity(hugr_children.size_hint().0); for child in hugr_children.clone() { - children.push(self.export_node_shallow(child)); + if let Some(child_id) = self.export_node_shallow(child) { + children.push(child_id); + } } - for (child, child_node_id) in hugr_children.zip(children.iter().copied()) { - self.export_node_deep(child, child_node_id); + for child in &children { + self.export_node_deep(*child); } let mut all_children = BumpVec::with_capacity_in( @@ -226,11 +235,25 @@ impl<'a> Context<'a> { result } - fn export_node_shallow(&mut self, node: Node) -> model::NodeId { + fn export_node_shallow(&mut self, node: Node) -> Option { + let optype = self.hugr.get_optype(node); + + // We skip nodes that are not exported as nodes in the model. + if let OpType::Const(_) + | OpType::Input(_) + | OpType::Output(_) + | OpType::ExitBlock(_) + | OpType::Case(_) = optype + { + return None; + } + let node_id = self.module.insert_node(model::Node::default()); - self.node_indices.insert(node, node_id); + self.node_to_id.insert(node, node_id); + self.id_to_node.insert(node_id, node); - let symbol = match self.hugr.get_optype(node) { + // We record the name of the symbol defined by the node, if any. + let symbol = match optype { OpType::FuncDefn(func_defn) => Some(func_defn.name.as_str()), OpType::FuncDecl(func_decl) => Some(func_decl.name.as_str()), OpType::AliasDecl(alias_decl) => Some(alias_decl.name.as_str()), @@ -244,16 +267,17 @@ impl<'a> Context<'a> { .expect("duplicate symbol"); } - node_id + Some(node_id) } - fn export_node_deep(&mut self, node: Node, node_id: model::NodeId) { + fn export_node_deep(&mut self, node_id: model::NodeId) { // We insert a dummy node with the invalid operation at this point to reserve // the node id. This is necessary to establish the correct node id for the // local scope introduced by some operations. We will overwrite this node later. let mut params: &[_] = &[]; let mut regions: &[_] = &[]; + let node = self.id_to_node[&node_id]; let optype = self.hugr.get_optype(node); let operation = match optype { @@ -358,7 +382,7 @@ impl<'a> Context<'a> { OpType::Call(call) => { // TODO: If the node is not connected to a function, we should do better than panic. let node = self.connected_function(node).unwrap(); - let symbol = self.node_indices[&node]; + let symbol = self.node_to_id[&node]; let mut args = BumpVec::new_in(self.bump); args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg))); let args = args.into_bump_slice(); @@ -370,7 +394,7 @@ impl<'a> Context<'a> { OpType::LoadFunction(load) => { // TODO: If the node is not connected to a function, we should do better than panic. let node = self.connected_function(node).unwrap(); - let symbol = self.node_indices[&node]; + let symbol = self.node_to_id[&node]; let mut args = BumpVec::new_in(self.bump); args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg))); @@ -380,8 +404,24 @@ impl<'a> Context<'a> { model::Operation::LoadFunc { func } } - OpType::Const(_) => todo!("Export const nodes?"), - OpType::LoadConstant(_) => todo!("Export load constant?"), + OpType::Const(_) => { + unreachable!("const nodes are filtered out by `export_node_shallow`") + } + + OpType::LoadConstant(_) => { + // TODO: If the node is not connected to a constant, we should do better than panic. + let const_node = self.hugr.static_source(node).unwrap(); + let const_node_op = self.hugr.get_optype(const_node); + + let OpType::Const(const_node_data) = const_node_op else { + panic!("expected `LoadConstant` node to be connected to a `Const` node"); + }; + + // TODO: Share the constant value between all nodes that load it. + + let value = self.export_value(&const_node_data.value); + model::Operation::Const { value } + } OpType::CallIndirect(_) => model::Operation::CustomFull { operation: self.resolve_symbol(OP_FUNC_CALL_INDIRECT), @@ -414,18 +454,6 @@ impl<'a> Context<'a> { .bump .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); - // PERFORMANCE: Currently the API does not appear to allow to get the extension - // set without copying it. - // NOTE: We assume here that the extension set of the dfg region must be the same - // as that of the node. This might change in the future. - let extensions = self.export_ext_set(&op.extension_delta()); - - if let Some(region) = - self.export_dfg_if_present(node, extensions, model::ScopeClosure::Closed) - { - regions = self.bump.alloc_slice_copy(&[region]); - } - model::Operation::CustomFull { operation } } @@ -436,18 +464,6 @@ impl<'a> Context<'a> { .bump .alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg))); - // PERFORMANCE: Currently the API does not appear to allow to get the extension - // set without copying it. - // NOTE: We assume here that the extension set of the dfg region must be the same - // as that of the node. This might change in the future. - let extensions = self.export_ext_set(&op.extension_delta()); - - if let Some(region) = - self.export_dfg_if_present(node, extensions, model::ScopeClosure::Closed) - { - regions = self.bump.alloc_slice_copy(&[region]); - } - model::Operation::CustomFull { operation } } }; @@ -549,7 +565,7 @@ impl<'a> Context<'a> { for (name, value) in opdef.iter_misc() { let name = self.bump.alloc_str(name); - let value = self.export_json(value); + let value = self.export_json_meta(value); meta.push(model::MetaItem { name, value }); } @@ -596,22 +612,6 @@ impl<'a> Context<'a> { }) } - /// Create a region from the given node's children, if it has any. - /// - /// See [`Self::export_dfg`]. - pub fn export_dfg_if_present( - &mut self, - node: Node, - extensions: model::TermId, - closure: model::ScopeClosure, - ) -> Option { - if self.hugr.children(node).next().is_none() { - None - } else { - Some(self.export_dfg(node, extensions, closure)) - } - } - /// Creates a data flow region from the given node's children. /// /// `Input` and `Output` nodes are used to determine the source and target ports of the region. @@ -628,44 +628,39 @@ impl<'a> Context<'a> { self.links.enter(region); } - let region_children = { - let children = self.hugr.children(node); - - // We skip the first two children, which are the `Input` and `Output` nodes. - // These nodes are not exported as model nodes themselves, but are used to determine - // the region's sources and targets. - let mut region_children = - BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump); - for child in children.skip(2) { - region_children.push(self.export_node_shallow(child)); - } - region_children.into_bump_slice() - }; + let mut sources: &[_] = &[]; + let mut targets: &[_] = &[]; + let mut input_types = None; + let mut output_types = None; - let mut children = self.hugr.children(node); - - // The first child is an `Input` node, which we use to determine the region's sources. - let input_node = children.next().unwrap(); - let OpType::Input(input_op) = self.hugr.get_optype(input_node) else { - panic!("expected an `Input` node as the first child node"); - }; - let sources = self.make_ports(input_node, Direction::Outgoing, input_op.types.len()); + let children = self.hugr.children(node); + let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump); - // The second child is an `Output` node, which we use to determine the region's targets. - let output_node = children.next().unwrap(); - let OpType::Output(output_op) = self.hugr.get_optype(output_node) else { - panic!("expected an `Output` node as the second child node"); - }; - let targets = self.make_ports(output_node, Direction::Incoming, output_op.types.len()); + for child in children { + match self.hugr.get_optype(child) { + OpType::Input(input) => { + sources = self.make_ports(child, Direction::Outgoing, input.types.len()); + input_types = Some(&input.types); + } + OpType::Output(output) => { + targets = self.make_ports(child, Direction::Incoming, output.types.len()); + output_types = Some(&output.types); + } + _ => { + if let Some(child_id) = self.export_node_shallow(child) { + region_children.push(child_id); + } + } + } + } - // Export the remaining children of the node. - for (child, child_node_id) in children.zip(region_children.iter().copied()) { - self.export_node_deep(child, child_node_id); + for child_id in ®ion_children { + self.export_node_deep(*child_id); } let signature = { - let inputs = self.export_type_row(&input_op.types); - let outputs = self.export_type_row(&output_op.types); + let inputs = self.export_type_row(input_types.unwrap()); + let outputs = self.export_type_row(output_types.unwrap()); Some(self.make_term(model::Term::FuncType { inputs, @@ -687,7 +682,7 @@ impl<'a> Context<'a> { kind: model::RegionKind::DataFlow, sources, targets, - children: region_children, + children: region_children.into_bump_slice(), meta: &[], // TODO: Export metadata signature, scope, @@ -705,58 +700,32 @@ impl<'a> Context<'a> { self.links.enter(region); } - let region_children = { - let children = self.hugr.children(node); - let mut region_children = - BumpVec::with_capacity_in(children.size_hint().0 - 1, self.bump); + let mut source = None; + let mut targets: &[_] = &[]; - // First export the children shallowly to allocate their IDs and register symbols. - for (i, child) in children.enumerate() { - // The second node is the exit block, which is not exported as a node itself. - if i == 1 { - continue; + let children = self.hugr.children(node); + let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 1, self.bump); + + for child in children { + match self.hugr.get_optype(child) { + OpType::ExitBlock(_) => { + targets = self.make_ports(child, Direction::Incoming, 1); } + _ => { + if let Some(child_id) = self.export_node_shallow(child) { + region_children.push(child_id); + } - region_children.push(self.export_node_shallow(child)); + if source.is_none() { + source = Some(self.get_link_index(child, IncomingPort::from(0))); + } + } } - - region_children.into_bump_slice() - }; - - let mut children_iter = self.hugr.children(node); - let mut region_children_iter = region_children.iter().copied(); - - // The first child is the entry block. - // We create a source port on the control flow region and connect it to the - // first input port of the exported entry block. - let source = { - let entry_block = children_iter.next().unwrap(); - let entry_node_id = region_children_iter.next().unwrap(); - - let OpType::DataflowBlock(_) = self.hugr.get_optype(entry_block) else { - panic!("expected a `DataflowBlock` node as the first child node"); - }; - - self.export_node_deep(entry_block, entry_node_id); - self.get_link_index(entry_block, IncomingPort::from(0)) - }; - - // The second child is the exit block. - // Contrary to the entry block, the exit block does not have a dataflow subgraph. - // We therefore do not export the block itself, but simply use its output ports - // as the target ports of the control flow region. - let exit_block = children_iter.next_back().unwrap(); - - let OpType::ExitBlock(_) = self.hugr.get_optype(exit_block) else { - panic!("expected an `ExitBlock` node as the second child node"); - }; - - // Export the remaining children of the node, except for the last one. - for (child, child_node_id) in children_iter.zip(region_children_iter) { - self.export_node_deep(child, child_node_id); } - let targets = self.make_ports(exit_block, Direction::Incoming, 1); + for child_id in ®ion_children { + self.export_node_deep(*child_id); + } // Get the signature of the control flow region. // This is the same as the signature of the parent node. @@ -773,9 +742,9 @@ impl<'a> Context<'a> { self.module.regions[region.index()] = model::Region { kind: model::RegionKind::ControlFlow, - sources: self.bump.alloc_slice_copy(&[source]), + sources: self.bump.alloc_slice_copy(&[source.unwrap()]), targets, - children: region_children, + children: region_children.into_bump_slice(), meta: &[], // TODO: Export metadata signature, scope, @@ -1029,6 +998,60 @@ impl<'a> Context<'a> { }) } + fn export_value(&mut self, value: &'a Value) -> model::TermId { + match value { + Value::Extension { e } => { + let json = match e.value().downcast_ref::() { + Some(custom) => serde_json::to_string(custom.value()).unwrap(), + None => serde_json::to_string(e.value()) + .expect("custom extension values should be serializable"), + }; + + let json = self.make_term(model::Term::Str(self.bump.alloc_str(&json))); + let runtime_type = self.export_type(&e.get_type()); + let extensions = self.export_ext_set(&e.extension_reqs()); + let args = self + .bump + .alloc_slice_copy(&[runtime_type, json, extensions]); + let symbol = self.resolve_symbol(TERM_JSON_CONST); + self.make_term(model::Term::ApplyFull { symbol, args }) + } + + Value::Function { hugr } => { + let outer_hugr = std::mem::replace(&mut self.hugr, hugr); + let outer_node_to_id = std::mem::take(&mut self.node_to_id); + + let region = match hugr.root_type() { + OpType::DFG(dfg) => { + let extensions = self.export_ext_set(&dfg.extension_delta()); + self.export_dfg(hugr.root(), extensions, model::ScopeClosure::Closed) + } + _ => panic!("Value::Function root must be a DFG"), + }; + + self.node_to_id = outer_node_to_id; + self.hugr = outer_hugr; + + self.make_term(model::Term::ConstFunc { region }) + } + + Value::Sum(sum) => { + let tag = sum.tag as _; + let mut values = BumpVec::with_capacity_in(sum.values.len(), self.bump); + + for value in &sum.values { + values.push(model::ListPart::Item(self.export_value(value))); + } + + let values = self.make_term(model::Term::List { + parts: values.into_bump_slice(), + }); + + self.make_term(model::Term::ConstAdt { tag, values }) + } + } + } + pub fn export_node_metadata( &mut self, metadata_map: &NodeMetadataMap, @@ -1037,14 +1060,14 @@ impl<'a> Context<'a> { for (name, value) in metadata_map { let name = self.bump.alloc_str(name); - let value = self.export_json(value); + let value = self.export_json_meta(value); meta.push(model::MetaItem { name, value }); } meta.into_bump_slice() } - pub fn export_json(&mut self, value: &serde_json::Value) -> model::TermId { + pub fn export_json_meta(&mut self, value: &serde_json::Value) -> model::TermId { let value = serde_json::to_string(value).expect("json values are always serializable"); let value = self.make_term(model::Term::Str(self.bump.alloc_str(&value))); let value = self.bump.alloc_slice_copy(&[value]); diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 72fe8601a..6a08b4a78 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -10,9 +10,10 @@ use crate::{ extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError}, hugr::{HugrMut, IdentList}, ops::{ - AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, DataflowBlock, ExitBlock, - FuncDecl, FuncDefn, Input, LoadFunction, Module, OpType, OpaqueOp, Output, Tag, TailLoop, - CFG, DFG, + constant::{CustomConst, CustomSerialized, OpaqueValue}, + AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, + ExitBlock, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, Module, OpType, OpaqueOp, + Output, Tag, TailLoop, Value, CFG, DFG, }, types::{ type_param::TypeParam, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, @@ -28,6 +29,7 @@ use smol_str::{SmolStr, ToSmolStr}; use thiserror::Error; const TERM_JSON: &str = "prelude.json"; +const TERM_JSON_CONST: &str = "prelude.const-json"; /// Error during import. #[derive(Debug, Clone, Error)] @@ -172,7 +174,7 @@ impl<'a> Context<'a> { for meta_item in node_data.meta { // TODO: For now we expect all metadata to be JSON since this is how // it is handled in `hugr-core`. - let value = self.import_json_value(meta_item.value)?; + let value = self.import_json_meta(meta_item.value)?; self.hugr.set_metadata(node, meta_item.name, value); } @@ -442,12 +444,6 @@ impl<'a> Context<'a> { let node = self.make_node(node_id, optype, parent)?; - match node_data.regions { - [] => {} - [region] => self.import_dfg_region(node_id, *region, node)?, - _ => return Err(error_unsupported!("multiple regions in custom operation")), - } - Ok(Some(node)) } @@ -508,6 +504,36 @@ impl<'a> Context<'a> { model::Operation::DeclareConstructor { .. } => Ok(None), model::Operation::DeclareOperation { .. } => Ok(None), + + model::Operation::Const { value } => { + let signature = node_data + .signature + .ok_or_else(|| error_uninferred!("node signature"))?; + let (_, outputs, _) = self.get_func_type(signature)?; + let outputs = self.import_closed_list(outputs)?; + let output = outputs + .first() + .ok_or(model::ModelError::TypeError(signature))?; + let datatype = self.import_type(*output)?; + + let imported_value = self.import_value(value, *output)?; + + let load_const_node = self.make_node( + node_id, + OpType::LoadConstant(LoadConstant { + datatype: datatype.clone(), + }), + parent, + )?; + + let const_node = self + .hugr + .add_node_with_parent(parent, OpType::Const(Const::new(imported_value))); + + self.hugr.connect(const_node, 0, load_const_node, 0); + + Ok(Some(load_const_node)) + } } } @@ -897,7 +923,7 @@ impl<'a> Context<'a> { model::Term::Apply { .. } => Err(error_unsupported!("custom type as `TypeParam`")), model::Term::ApplyFull { .. } => Err(error_unsupported!("custom type as `TypeParam`")), - model::Term::Quote { .. } => Err(error_unsupported!("`(quote ...)` as `TypeParam`")), + model::Term::Const { .. } => Err(error_unsupported!("`(const ...)` as `TypeParam`")), model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), model::Term::ListType { item_type } => { @@ -918,9 +944,9 @@ impl<'a> Context<'a> { | model::Term::ExtSet { .. } | model::Term::Adt { .. } | model::Term::Control { .. } - | model::Term::NonLinearConstraint { .. } => { - Err(model::ModelError::TypeError(term_id).into()) - } + | model::Term::NonLinearConstraint { .. } + | model::Term::ConstFunc { .. } + | model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()), model::Term::ControlType => { Err(error_unsupported!("type of control types as `TypeParam`")) @@ -959,9 +985,6 @@ impl<'a> Context<'a> { arg: value.to_string(), }), - model::Term::Quote { .. } => Ok(TypeArg::Type { - ty: self.import_type(term_id)?, - }), model::Term::Nat(value) => Ok(TypeArg::BoundedNat { n: *value }), model::Term::ExtSet { .. } => Ok(TypeArg::Extensions { es: self.import_extension_set(term_id)?, @@ -976,6 +999,11 @@ impl<'a> Context<'a> { model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")), model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")), model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")), + model::Term::Const { .. } => Err(error_unsupported!("`const` as `TypeArg`")), + model::Term::ConstAdt { .. } => Err(error_unsupported!("adt constant as `TypeArg`")), + model::Term::ConstFunc { .. } => { + Err(error_unsupported!("function constant as `TypeArg`")) + } model::Term::FuncType { .. } | model::Term::Adt { .. } @@ -1045,12 +1073,12 @@ impl<'a> Context<'a> { let (extension, id) = self.import_custom_name(name)?; let extension_ref = - self.extensions.get(&extension.to_string()).ok_or_else(|| { - ImportError::Extension { + self.extensions + .get(&extension) + .ok_or_else(|| ImportError::Extension { missing_ext: extension.clone(), available: self.extensions.ids().cloned().collect(), - } - })?; + })?; Ok(TypeBase::new_extension(CustomType::new( id, @@ -1090,16 +1118,16 @@ impl<'a> Context<'a> { | model::Term::StaticType | model::Term::Type | model::Term::Constraint - | model::Term::Quote { .. } + | model::Term::Const { .. } | model::Term::Str(_) | model::Term::ExtSet { .. } | model::Term::List { .. } | model::Term::Control { .. } | model::Term::ControlType | model::Term::Nat(_) - | model::Term::NonLinearConstraint { .. } => { - Err(model::ModelError::TypeError(term_id).into()) - } + | model::Term::NonLinearConstraint { .. } + | model::Term::ConstFunc { .. } + | model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()), } } @@ -1234,7 +1262,7 @@ impl<'a> Context<'a> { } } - fn import_json_value( + fn import_json_meta( &mut self, term_id: model::TermId, ) -> Result { @@ -1263,6 +1291,116 @@ impl<'a> Context<'a> { Ok(json_value) } + + fn import_value( + &mut self, + term_id: model::TermId, + type_id: model::TermId, + ) -> Result { + let term_data = self.get_term(term_id)?; + + match term_data { + model::Term::Wildcard => Err(error_uninferred!("wildcard")), + model::Term::Apply { .. } => { + Err(error_uninferred!("application with implicit parameters")) + } + model::Term::Var(_) => Err(error_unsupported!("constant value containing a variable")), + + model::Term::ApplyFull { symbol, args } => { + let symbol_name = self.get_symbol_name(*symbol)?; + + if symbol_name == TERM_JSON_CONST { + let value = args.get(1).ok_or(model::ModelError::TypeError(term_id))?; + + let model::Term::Str(json) = self.get_term(*value)? else { + return Err(model::ModelError::TypeError(term_id).into()); + }; + + // We attempt to deserialize as the custom const directly. + // This might fail due to the custom const struct not being included when + // this code was compiled; in that case, we fall back to the serialized form. + let value: Option> = serde_json::from_str(json).ok(); + + if let Some(value) = value { + let opaque_value = OpaqueValue::from(value); + return Ok(Value::Extension { e: opaque_value }); + } else { + let runtime_type = + args.first().ok_or(model::ModelError::TypeError(term_id))?; + let runtime_type = self.import_type(*runtime_type)?; + + let extensions = + args.get(2).ok_or(model::ModelError::TypeError(term_id))?; + let extensions = self.import_extension_set(*extensions)?; + + let value: serde_json::Value = serde_json::from_str(json) + .map_err(|_| model::ModelError::TypeError(term_id))?; + let custom_const = CustomSerialized::new(runtime_type, value, extensions); + let opaque_value = OpaqueValue::new(custom_const); + return Ok(Value::Extension { e: opaque_value }); + } + } + + Err(error_unsupported!("constant value that is not JSON data")) + // TODO: This should ultimately include the following cases: + // - function definitions + // - custom constructors for values + } + + model::Term::StaticType + | model::Term::Constraint + | model::Term::Const { .. } + | model::Term::List { .. } + | model::Term::ListType { .. } + | model::Term::Str(_) + | model::Term::StrType + | model::Term::Nat(_) + | model::Term::NatType + | model::Term::ExtSet { .. } + | model::Term::ExtSetType + | model::Term::Adt { .. } + | model::Term::FuncType { .. } + | model::Term::Control { .. } + | model::Term::ControlType + | model::Term::Type + | model::Term::NonLinearConstraint { .. } => { + Err(model::ModelError::TypeError(term_id).into()) + } + + model::Term::ConstFunc { .. } => Err(error_unsupported!("constant function value")), + + model::Term::ConstAdt { tag, values } => { + let model::Term::Adt { variants } = self.get_term(type_id)? else { + return Err(model::ModelError::TypeError(term_id).into()); + }; + + let values = self.import_closed_list(*values)?; + let variants = self.import_closed_list(*variants)?; + + let variant = variants + .get(*tag as usize) + .ok_or(model::ModelError::TypeError(term_id))?; + let variant = self.import_closed_list(*variant)?; + + let items = values + .iter() + .zip(variant.iter()) + .map(|(value, typ)| self.import_value(*value, *typ)) + .collect::, _>>()?; + + let typ = { + // TODO: Import as a `SumType` directly and avoid the copy. + let typ: Type = self.import_type(type_id)?; + match typ.as_type_enum() { + TypeEnum::Sum(sum) => sum.clone(), + _ => unreachable!(), + } + }; + + Ok(Value::sum(*tag as _, items, typ).unwrap()) + } + } + } } /// Information about a local variable. diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index ca67024f4..c009fb848 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -309,6 +309,12 @@ impl From for OpaqueValue { } } +impl From> for OpaqueValue { + fn from(value: Box) -> Self { + Self { v: value } + } +} + impl PartialEq for OpaqueValue { fn eq(&self, other: &Self) -> bool { self.value().equal_consts(other.value()) diff --git a/hugr-core/tests/model.rs b/hugr-core/tests/model.rs index 7ae4010f4..ea5b69b22 100644 --- a/hugr-core/tests/model.rs +++ b/hugr-core/tests/model.rs @@ -67,3 +67,10 @@ pub fn test_roundtrip_constraints() { "../../hugr-model/tests/fixtures/model-constraints.edn" ))); } + +#[test] +pub fn test_roundtrip_const() { + insta::assert_snapshot!(roundtrip(include_str!( + "../../hugr-model/tests/fixtures/model-const.edn" + ))); +} diff --git a/hugr-core/tests/snapshots/model__roundtrip_add.snap b/hugr-core/tests/snapshots/model__roundtrip_add.snap index 7ffec5ef9..288bce4ba 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_add.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_add.snap @@ -15,13 +15,13 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-add. (dfg [%0 %1] [%2] (signature - (fn + (-> [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) ((@ arithmetic.int.iadd) [%0 %1] [%2] (signature - (fn + (-> [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int)))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_alias.snap b/hugr-core/tests/snapshots/model__roundtrip_alias.snap index 27fdd4740..de4d36952 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_alias.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_alias.snap @@ -10,4 +10,4 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-alia (define-alias local.int type (@ arithmetic.int.types.int)) -(define-alias local.endo type (fn [] [] (ext))) +(define-alias local.endo type (-> [] [] (ext))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_call.snap b/hugr-core/tests/snapshots/model__roundtrip_call.snap index 2b37b5a20..7ade416d1 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_call.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_call.snap @@ -28,38 +28,39 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-call (dfg [%0] [%1] (signature - (fn + (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))) (call (@ example.callee (ext)) [%0] [%1] (signature - (fn + (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int)))))) (define-func example.load [] - [(fn + [(-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext) (dfg + [] [%0] (signature - (fn + (-> [] - [(fn + [(-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext))) - (load-func (@ example.caller) + (load-func (@ example.caller) [] [%0] (signature - (fn + (-> [] - [(fn + [(-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index e39f0d37d..d3ed92bc7 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -9,21 +9,21 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. [?0] [?0] (ext) (dfg [%0] [%1] - (signature (fn [?0] [?0] (ext))) + (signature (-> [?0] [?0] (ext))) (cfg [%0] [%1] - (signature (fn [?0] [?0] (ext))) + (signature (-> [?0] [?0] (ext))) (cfg - [%4] [%8] - (signature (fn [?0] [?0] (ext))) - (block [%4] [%5] - (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) + [%2] [%3] + (signature (-> [?0] [?0] (ext))) + (block [%2] [%6] + (signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg - [%2] [%3] - (signature (fn [?0] [(adt [[?0]])] (ext))) - (tag 0 [%2] [%3] (signature (fn [?0] [(adt [[?0]])] (ext)))))) - (block [%5] [%8] - (signature (fn [(ctrl [?0])] [(ctrl [?0])] (ext))) + [%4] [%5] + (signature (-> [?0] [(adt [[?0]])] (ext))) + (tag 0 [%4] [%5] (signature (-> [?0] [(adt [[?0]])] (ext)))))) + (block [%6] [%3] + (signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg - [%6] [%7] - (signature (fn [?0] [(adt [[?0]])] (ext))) - (tag 0 [%6] [%7] (signature (fn [?0] [(adt [[?0]])] (ext)))))))))) + [%7] [%8] + (signature (-> [?0] [(adt [[?0]])] (ext))) + (tag 0 [%7] [%8] (signature (-> [?0] [(adt [[?0]])] (ext)))))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_cond.snap b/hugr-core/tests/snapshots/model__roundtrip_cond.snap index 92ab0cb4d..45c654ac4 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cond.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cond.snap @@ -15,34 +15,34 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cond (dfg [%0 %1] [%2] (signature - (fn + (-> [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) (cond [%0 %1] [%2] (signature - (fn + (-> [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) (dfg [%3] [%3] (signature - (fn + (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))) (dfg [%4] [%5] (signature - (fn + (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) ((@ arithmetic.int.ineg) [%4] [%5] (signature - (fn + (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int)))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_const.snap b/hugr-core/tests/snapshots/model__roundtrip_const.snap new file mode 100644 index 000000000..6b6bc1464 --- /dev/null +++ b/hugr-core/tests/snapshots/model__roundtrip_const.snap @@ -0,0 +1,79 @@ +--- +source: hugr-core/tests/model.rs +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-const.edn\"))" +--- +(hugr 0) + +(import prelude.const-json) + +(import arithmetic.float.types.float64) + +(define-func example.bools + [] [(adt [[] []]) (adt [[] []])] (ext) + (dfg + [] [%0 %1] + (signature (-> [] [(adt [[] []]) (adt [[] []])] (ext))) + (const (tag 0 []) [] [%0] (signature (-> [] [(adt [[] []])] (ext)))) + (const (tag 1 []) [] [%1] (signature (-> [] [(adt [[] []])] (ext)))))) + +(define-func example.make-pair + [] + [(adt + [[(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)]])] + (ext) + (dfg + [] [%0] + (signature + (-> + [] + [(adt + [[(@ arithmetic.float.types.float64) + (@ arithmetic.float.types.float64)]])] + (ext))) + (const + (tag + 0 + [(@ + prelude.const-json + (@ arithmetic.float.types.float64) + "{\"c\":\"ConstF64\",\"v\":{\"value\":2.0}}" + (ext arithmetic.float.types)) + (@ + prelude.const-json + (@ arithmetic.float.types.float64) + "{\"c\":\"ConstF64\",\"v\":{\"value\":3.0}}" + (ext arithmetic.float.types))]) + [] [%0] + (signature + (-> + [] + [(adt + [[(@ arithmetic.float.types.float64) + (@ arithmetic.float.types.float64)]])] + (ext)))))) + +(define-func example.f64 + [] [(@ arithmetic.float.types.float64)] (ext) + (dfg + [] [%0 %1] + (signature + (-> + [] + [(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)] + (ext))) + (const + (@ + prelude.const-json + (@ arithmetic.float.types.float64) + "{\"c\":\"ConstF64\",\"v\":{\"value\":1.0}}" + (ext arithmetic.float.types)) + [] [%0] + (signature (-> [] [(@ arithmetic.float.types.float64)] (ext)))) + (const + (@ + prelude.const-json + (@ arithmetic.float.types.float64) + "{\"c\":\"ConstUnknown\",\"v\":{\"value\":1.0}}" + (ext)) + [] [%1] + (signature (-> [] [(@ arithmetic.float.types.float64)] (ext)))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_loop.snap b/hugr-core/tests/snapshots/model__roundtrip_loop.snap index a513318ae..a7c21dfec 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_loop.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_loop.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-loop.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-loop.edn\"))" --- (hugr 0) @@ -9,11 +9,11 @@ expression: "roundtrip(include_str!(\"fixtures/model-loop.edn\"))" [?0] [?0] (ext) (dfg [%0] [%1] - (signature (fn [?0] [?0] (ext))) + (signature (-> [?0] [?0] (ext))) (tail-loop [%0] [%1] - (signature (fn [?0] [?0] (ext))) + (signature (-> [?0] [?0] (ext))) (dfg [%2] [%3] - (signature (fn [?0] [(adt [[?0] [?0]])] (ext))) - (tag 0 [%2] [%3] (signature (fn [?0] [(adt [[?0] [?0]])] (ext)))))))) + (signature (-> [?0] [(adt [[?0] [?0]])] (ext))) + (tag 0 [%2] [%3] (signature (-> [?0] [(adt [[?0] [?0]])] (ext)))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_params.snap b/hugr-core/tests/snapshots/model__roundtrip_params.snap index ab2b98d8c..214cb9755 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_params.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_params.snap @@ -1,6 +1,6 @@ --- source: hugr-core/tests/model.rs -expression: "roundtrip(include_str!(\"fixtures/model-params.edn\"))" +expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-params.edn\"))" --- (hugr 0) @@ -8,4 +8,4 @@ expression: "roundtrip(include_str!(\"fixtures/model-params.edn\"))" (forall ?0 type) (forall ?1 type) [?0 ?1] [?1 ?0] (ext) - (dfg [%0 %1] [%1 %0] (signature (fn [?0 ?1] [?1 ?0] (ext))))) + (dfg [%0 %1] [%1 %0] (signature (-> [?0 ?1] [?1 ?0] (ext))))) diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 259abfa69..4a2627a69 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -83,8 +83,7 @@ impl<'c> Emission<'c> { /// That function must take no arguments and return an `i64`. pub fn exec_i64(&self, entry: impl AsRef) -> Result { let gv = self.exec_impl(entry)?; - let x: u64 = gv.as_int(true).try_into().unwrap(); - Ok(x as i64) + Ok(gv.as_int(true) as i64) } /// JIT and execute the function named `entry` in the inner module. diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index 37f54241e..55e4cc1f5 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -346,7 +346,7 @@ mod test { .instantiate_extension_op(name.as_ref(), [(log_width as u64).into()]) .unwrap(); let outputs = hugr_builder - .add_dataflow_op(ext_op, input_wires.into_iter()) + .add_dataflow_op(ext_op, input_wires) .unwrap() .outputs(); hugr_builder.finish_with_outputs(outputs).unwrap() diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index b3bb1f0f2..060e8af3b 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -55,6 +55,7 @@ struct Operation { constructorDecl @15 :ConstructorDecl; operationDecl @16 :OperationDecl; import @17 :Text; + const @18 :TermId; } struct FuncDefn { @@ -140,7 +141,7 @@ struct Term { } apply @5 :Apply; applyFull @6 :ApplyFull; - quote @7 :TermId; + const @7 :Const; list @8 :ListTerm; listType @9 :TermId; string @10 :Text; @@ -154,6 +155,8 @@ struct Term { control @18 :TermId; controlType @19 :Void; nonLinearConstraint @20 :TermId; + constFunc @22 :RegionId; + constAdt @23 :ConstAdt; } struct Apply { @@ -188,11 +191,21 @@ struct Term { } } + struct ConstAdt { + tag @0 :UInt16; + values @1 :TermId; + } + struct FuncType { inputs @0 :TermId; outputs @1 :TermId; extensions @2 :TermId; } + + struct Const { + type @0 :TermId; + extensions @1 :TermId; + } } struct Param { diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index b14ca4482..2ea0e7742 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -198,6 +198,9 @@ fn read_operation<'a>( Which::Import(name) => model::Operation::Import { name: bump.alloc_str(name?.to_str()?), }, + Which::Const(value) => model::Operation::Const { + value: model::TermId(value), + }, }) } @@ -274,9 +277,13 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult model::Term::ApplyFull { symbol, args } } - Which::Quote(r#type) => model::Term::Quote { - r#type: model::TermId(r#type), - }, + Which::Const(reader) => { + let reader = reader?; + model::Term::Const { + r#type: model::TermId(reader.get_type()), + extensions: model::TermId(reader.get_extensions()), + } + } Which::List(reader) => { let reader = reader?; @@ -317,6 +324,17 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult Which::NonLinearConstraint(term) => model::Term::NonLinearConstraint { term: model::TermId(term), }, + + Which::ConstFunc(region) => model::Term::ConstFunc { + region: model::RegionId(region), + }, + + Which::ConstAdt(reader) => { + let reader = reader?; + let tag = reader.get_tag(); + let values = model::TermId(reader.get_values()); + model::Term::ConstAdt { tag, values } + } }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index ea495db54..3a1e1beba 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -103,6 +103,8 @@ fn write_operation(mut builder: hugr_capnp::operation::Builder, operation: &mode } model::Operation::Invalid => builder.set_invalid(()), + + model::Operation::Const { value } => builder.set_const(value.0), } } @@ -161,7 +163,11 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { model::Term::NatType => builder.set_nat_type(()), model::Term::ExtSetType => builder.set_ext_set_type(()), model::Term::Adt { variants } => builder.set_adt(variants.0), - model::Term::Quote { r#type } => builder.set_quote(r#type.0), + model::Term::Const { r#type, extensions } => { + let mut builder = builder.init_const(); + builder.set_type(r#type.0); + builder.set_extensions(extensions.0); + } model::Term::Control { values } => builder.set_control(values.0), model::Term::ControlType => builder.set_control_type(()), @@ -201,6 +207,16 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { model::Term::NonLinearConstraint { term } => { builder.set_non_linear_constraint(term.0); } + + model::Term::ConstFunc { region } => { + builder.set_const_func(region.0); + } + + model::Term::ConstAdt { tag, values } => { + let mut builder = builder.init_const_adt(); + builder.set_tag(*tag); + builder.set_values(values.0); + } } } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index ad3733079..f9da742e9 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -375,6 +375,12 @@ pub enum Operation<'a> { /// The name of the symbol to be imported. name: &'a str, }, + + /// Create a constant value. + Const { + /// The term that describes how to construct the constant value. + value: TermId, + }, } impl<'a> Operation<'a> { @@ -559,14 +565,18 @@ pub enum Term<'a> { args: &'a [TermId], }, - /// Quote a runtime type as a static type. + /// Type for a constant runtime value. /// - /// `(quote T) : static` where `T : type`. - Quote { - /// The runtime type to be quoted. + /// `(const T) : static` where `T : type`. + Const { + /// The runtime type of the constant value. /// /// **Type:** `type` r#type: TermId, + /// The extension set required to be present in order to use the constant value. + /// + /// **Type:** `ext-set` + extensions: TermId, }, /// A list. May include individual items or other lists to be spliced in. @@ -662,6 +672,20 @@ pub enum Term<'a> { /// The runtime type that must be copyable and discardable. term: TermId, }, + + /// A constant anonymous function. + ConstFunc { + /// The body of the constant anonymous function. + region: RegionId, + }, + + /// A constant value for an algebraic data type. + ConstAdt { + /// The tag of the variant. + tag: u16, + /// The values of the variant. + values: TermId, + }, } /// A part of a list term. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 4fd34f223..3d37b9878 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -35,6 +35,7 @@ node = { | node_cond | node_tag | node_import + | node_const | node_custom } @@ -53,6 +54,7 @@ node_tail_loop = { "(" ~ "tail-loop" ~ port_lists? ~ signature? ~ meta* node_cond = { "(" ~ "cond" ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_tag = { "(" ~ "tag" ~ tag ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } node_import = { "(" ~ "import" ~ symbol ~ meta* ~ ")" } +node_const = { "(" ~ "const" ~ term ~ port_lists? ~ signature? ~ meta* ~ ")" } node_custom = { "(" ~ (term_apply | term_apply_full) ~ port_lists? ~ signature? ~ meta* ~ region* ~ ")" } signature = { "(" ~ "signature" ~ term ~ ")" } @@ -77,7 +79,7 @@ term = { | term_static | term_constraint | term_var - | term_quote + | term_const | term_list | term_list_type | term_str @@ -93,6 +95,8 @@ term = { | term_apply_full | term_apply | term_non_linear + | term_const_func + | term_const_adt } term_wildcard = { "_" } @@ -102,7 +106,7 @@ term_constraint = { "constraint" } term_var = { "?" ~ identifier } term_apply_full = { ("(" ~ "@" ~ symbol ~ term* ~ ")") } term_apply = { symbol | ("(" ~ symbol ~ term* ~ ")") } -term_quote = { "(" ~ "quote" ~ term ~ ")" } +term_const = { "(" ~ "const" ~ term ~ term ~ ")" } term_list = { "[" ~ (spliced_term | term)* ~ "]" } term_list_type = { "(" ~ "list" ~ term ~ ")" } term_str = { string } @@ -112,9 +116,11 @@ term_nat_type = { "nat" } term_ext_set = { "(" ~ "ext" ~ (spliced_term | ext_name)* ~ ")" } term_ext_set_type = { "ext-set" } term_adt = { "(" ~ "adt" ~ term ~ ")" } -term_func_type = { "(" ~ "fn" ~ term ~ term ~ term ~ ")" } +term_func_type = { "(" ~ "->" ~ term ~ term ~ term ~ ")" } term_ctrl = { "(" ~ "ctrl" ~ term ~ ")" } term_ctrl_type = { "ctrl" } term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } +term_const_func = { "(" ~ "fn" ~ term ~ ")" } +term_const_adt = { "(" ~ "tag" ~ tag ~ term* ~ ")" } spliced_term = { term ~ "..." } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 4ad77d914..34435b2a3 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -165,9 +165,10 @@ impl<'a> ParseContext<'a> { } } - Rule::term_quote => { + Rule::term_const => { let r#type = self.parse_term(inner.next().unwrap())?; - Term::Quote { r#type } + let extensions = self.parse_term(inner.next().unwrap())?; + Term::Const { r#type, extensions } } Rule::term_list => { @@ -250,6 +251,17 @@ impl<'a> ParseContext<'a> { Term::NonLinearConstraint { term } } + Rule::term_const_func => { + let region = self.parse_region(inner.next().unwrap(), ScopeClosure::Closed)?; + Term::ConstFunc { region } + } + + Rule::term_const_adt => { + let tag = inner.next().unwrap().as_str().parse().unwrap(); + let values = self.parse_term(inner.next().unwrap())?; + Term::ConstAdt { tag, values } + } + r => unreachable!("term: {:?}", r), }; @@ -584,6 +596,23 @@ impl<'a> ParseContext<'a> { } } + Rule::node_const => { + let value = self.parse_term(inner.next().unwrap())?; + let inputs = self.parse_port_list(&mut inner)?; + let outputs = self.parse_port_list(&mut inner)?; + let signature = self.parse_signature(&mut inner)?; + let meta = self.parse_meta(&mut inner)?; + Node { + operation: Operation::Const { value }, + inputs, + outputs, + params: &[], + regions: &[], + meta, + signature, + } + } + _ => unreachable!(), }; diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index ba7874e45..430f07e99 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -370,6 +370,14 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(*name); this.print_meta(node_data.meta) } + + Operation::Const { value } => { + this.print_text("const"); + this.print_term(*value)?; + this.print_port_lists(node_data.inputs, node_data.outputs)?; + this.print_signature(node_data.signature)?; + this.print_meta(node_data.meta) + } }) } @@ -422,7 +430,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { first: &'a [LinkIndex], second: &'a [LinkIndex], ) -> PrintResult<()> { - if !first.is_empty() && !second.is_empty() { + if !first.is_empty() || !second.is_empty() { self.print_group(|this| { this.print_port_list(first)?; this.print_port_list(second) @@ -520,9 +528,10 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { Ok(()) }), - Term::Quote { r#type } => self.print_parens(|this| { - this.print_text("quote"); - this.print_term(*r#type) + Term::Const { r#type, extensions } => self.print_parens(|this| { + this.print_text("const"); + this.print_term(*r#type)?; + this.print_term(*extensions) }), Term::List { .. } => self.print_brackets(|this| this.print_list_parts(term_id)), Term::ListType { item_type } => self.print_parens(|this| { @@ -563,7 +572,7 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { outputs, extensions, } => self.print_parens(|this| { - this.print_text("fn"); + this.print_text("->"); this.print_term(*inputs)?; this.print_term(*outputs)?; this.print_term(*extensions) @@ -580,6 +589,15 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text("nonlinear"); this.print_term(*term) }), + Term::ConstFunc { region } => self.print_parens(|this| { + this.print_text("fn"); + this.print_region(*region) + }), + Term::ConstAdt { tag, values } => self.print_parens(|this| { + this.print_text("tag"); + this.print_text(tag.to_string()); + this.print_term(*values) + }), } } diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 80157c23e..6a00a6b35 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -63,3 +63,8 @@ pub fn test_constraints() { pub fn test_lists() { binary_roundtrip(include_str!("fixtures/model-lists.edn")); } + +#[test] +pub fn test_const() { + binary_roundtrip(include_str!("fixtures/model-const.edn")); +} diff --git a/hugr-model/tests/fixtures/model-add.edn b/hugr-model/tests/fixtures/model-add.edn index f7783cb41..ed8476ea9 100644 --- a/hugr-model/tests/fixtures/model-add.edn +++ b/hugr-model/tests/fixtures/model-add.edn @@ -7,7 +7,7 @@ (dfg [%0 %1] [%2] - (signature (fn [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + (signature (-> [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) ((@ arithmetic.int.iadd) [%0 %1] [%2] - (signature (fn [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + (signature (-> [(@ arithmetic.int.types.int) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) ))) diff --git a/hugr-model/tests/fixtures/model-alias.edn b/hugr-model/tests/fixtures/model-alias.edn index 9783b3dbd..2998410ad 100644 --- a/hugr-model/tests/fixtures/model-alias.edn +++ b/hugr-model/tests/fixtures/model-alias.edn @@ -4,4 +4,4 @@ (define-alias local.int type (@ arithmetic.int.types.int)) -(define-alias local.endo type (fn [] [] (ext))) +(define-alias local.endo type (-> [] [] (ext))) diff --git a/hugr-model/tests/fixtures/model-call.edn b/hugr-model/tests/fixtures/model-call.edn index 87c6f7a3a..839fb278a 100644 --- a/hugr-model/tests/fixtures/model-call.edn +++ b/hugr-model/tests/fixtures/model-call.edn @@ -11,14 +11,14 @@ (meta doc.title (prelude.json "\"Caller\"")) (meta doc.description (prelude.json "\"This defines a function that calls the function which we declared earlier.\"")) (dfg [%3] [%4] - (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) (call (@ example.callee (ext)) [%3] [%4] - (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))))) + (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))))) (define-func example.load - [] [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext) + [] [(-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext) (dfg [] [%5] - (signature (fn [] [(fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext))) + (signature (-> [] [(-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext arithmetic.int))] (ext))) (load-func (@ example.caller) [] [%5]))) diff --git a/hugr-model/tests/fixtures/model-cfg.edn b/hugr-model/tests/fixtures/model-cfg.edn index 92ae19441..3987450f9 100644 --- a/hugr-model/tests/fixtures/model-cfg.edn +++ b/hugr-model/tests/fixtures/model-cfg.edn @@ -4,14 +4,14 @@ (forall ?a type) [?a] [?a] (ext) (dfg [%0] [%1] - (signature (fn [?a] [?a] (ext))) + (signature (-> [?a] [?a] (ext))) (cfg [%0] [%1] - (signature (fn [?a] [?a] (ext))) + (signature (-> [?a] [?a] (ext))) (cfg [%2] [%4] - (signature (fn [(ctrl [?a])] [(ctrl [?a])] (ext))) + (signature (-> [(ctrl [?a])] [(ctrl [?a])] (ext))) (block [%2] [%4] - (signature (fn [(ctrl [?a])] [(ctrl [?a])] (ext))) + (signature (-> [(ctrl [?a])] [(ctrl [?a])] (ext))) (dfg [%5] [%6] - (signature (fn [?a] [(adt [[?a]])] (ext))) + (signature (-> [?a] [(adt [[?a]])] (ext))) (tag 0 [%5] [%6] - (signature (fn [?a] [(adt [[?a]])] (ext)))))))))) + (signature (-> [?a] [(adt [[?a]])] (ext)))))))))) diff --git a/hugr-model/tests/fixtures/model-cond.edn b/hugr-model/tests/fixtures/model-cond.edn index aa1ecef7d..d6b84d9fa 100644 --- a/hugr-model/tests/fixtures/model-cond.edn +++ b/hugr-model/tests/fixtures/model-cond.edn @@ -4,12 +4,12 @@ [(@ arithmetic.int.types.int)] (ext) (dfg [%0 %1] [%2] - (signature (fn [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + (signature (-> [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) (cond [%0 %1] [%2] - (signature (fn [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + (signature (-> [(adt [[] []]) (@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) (dfg [%3] [%3] - (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))) + (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))) (dfg [%4] [%5] - (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) + (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext))) ((@ arithmetic.int.ineg) [%4] [%5] - (signature (fn [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))))))) + (signature (-> [(@ arithmetic.int.types.int)] [(@ arithmetic.int.types.int)] (ext)))))))) diff --git a/hugr-model/tests/fixtures/model-const.edn b/hugr-model/tests/fixtures/model-const.edn new file mode 100644 index 000000000..025e77043 --- /dev/null +++ b/hugr-model/tests/fixtures/model-const.edn @@ -0,0 +1,50 @@ +(hugr 0) + +(define-func example.bools + [] + [(adt [[] []]) (adt [[] []])] + (ext) + (dfg [] [%false %true] + (signature (-> [] [(adt [[] []]) (adt [[] []])] (ext))) + (const (tag 0 []) [] [%false] + (signature (-> [] [(adt [[] []])] (ext)))) + (const (tag 1 []) [] [%true] + (signature (-> [] [(adt [[] []])] (ext)))))) + +(define-func example.make-pair + [] + [(adt [[(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)]])] + (ext) + (dfg [] [%pair] + (signature + (-> + [] + [(adt [[(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)]])] + (ext))) + (const + (tag + 0 + [(@ prelude.const-json (@ arithmetic.float.types.float64) "{\"c\":\"ConstF64\",\"v\":{\"value\":2.0}}" (ext)) + (@ prelude.const-json (@ arithmetic.float.types.float64) "{\"c\":\"ConstF64\",\"v\":{\"value\":3.0}}" (ext))]) + [] [%pair] + (signature + (-> + [] + [(adt [[(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)]])] + (ext)))))) + +(define-func example.f64 + [] + [(@ arithmetic.float.types.float64)] + (ext) + (dfg [] [%0 %1] + (signature (-> [] [(@ arithmetic.float.types.float64) (@ arithmetic.float.types.float64)] (ext))) + (const + (@ prelude.const-json (@ arithmetic.float.types.float64) "{\"c\":\"ConstF64\",\"v\":{\"value\":1.0}}" (ext)) + [] [%0] + (signature (-> [] [(@ arithmetic.float.types.float64)] (ext)))) + ; The following const is to test that import/export can deal with unknown constants. + (const + (@ prelude.const-json (@ arithmetic.float.types.float64) "{\"c\":\"ConstUnknown\",\"v\":{\"value\":1.0}}" (ext)) + [] [%1] + (signature (-> [] [(@ arithmetic.float.types.float64)] (ext)))))) diff --git a/hugr-model/tests/fixtures/model-decl-exts.edn b/hugr-model/tests/fixtures/model-decl-exts.edn index c38c78cdf..bee0d68fc 100644 --- a/hugr-model/tests/fixtures/model-decl-exts.edn +++ b/hugr-model/tests/fixtures/model-decl-exts.edn @@ -9,5 +9,5 @@ (declare-operation array.Init (param ?t type) (param ?n nat) - (fn [?t] [(array.Array ?t ?n)] (ext array)) + (-> [?t] [(array.Array ?t ?n)] (ext array)) (meta docs.description "Initialize an array of size ?n with copies of a default value.")) diff --git a/hugr-model/tests/fixtures/model-lists.edn b/hugr-model/tests/fixtures/model-lists.edn index 1385a0e2a..db84ffe72 100644 --- a/hugr-model/tests/fixtures/model-lists.edn +++ b/hugr-model/tests/fixtures/model-lists.edn @@ -4,7 +4,7 @@ (forall ?inputs (list type)) (forall ?outputs (list type)) (forall ?exts ext-set) - (fn [(fn ?inputs ?outputs ?exts) ?inputs ...] ?outputs ?exts)) + (-> [(-> ?inputs ?outputs ?exts) ?inputs ...] ?outputs ?exts)) (declare-operation core.compose-parallel (forall ?inputs-0 (list type)) @@ -12,9 +12,9 @@ (forall ?outputs-0 (list type)) (forall ?outputs-1 (list type)) (forall ?exts ext-set) - (fn - [(fn ?inputs-0 ?outputs-0 ?exts) - (fn ?inputs-1 ?outputs-1 ?exts) + (-> + [(-> ?inputs-0 ?outputs-0 ?exts) + (-> ?inputs-1 ?outputs-1 ?exts) ?inputs-0 ... ?inputs-1 ...] [?outputs-0 ... ?outputs-1 ...] diff --git a/hugr-model/tests/fixtures/model-loop.edn b/hugr-model/tests/fixtures/model-loop.edn index 5df4b2a87..f2c49d9d6 100644 --- a/hugr-model/tests/fixtures/model-loop.edn +++ b/hugr-model/tests/fixtures/model-loop.edn @@ -4,10 +4,10 @@ (forall ?a type) [?a] [?a] (ext) (dfg [%0] [%1] - (signature (fn [?a] [?a] (ext))) + (signature (-> [?a] [?a] (ext))) (tail-loop [%0] [%1] - (signature (fn [?a] [?a] (ext))) + (signature (-> [?a] [?a] (ext))) (dfg [%2] [%3] - (signature (fn [?a] [(adt [[?a] [?a]])] (ext))) + (signature (-> [?a] [(adt [[?a] [?a]])] (ext))) (tag 0 [%2] [%3] - (signature (fn [?a] [(adt [[?a] [?a]])] (ext)))))))) + (signature (-> [?a] [(adt [[?a] [?a]])] (ext)))))))) diff --git a/hugr-model/tests/fixtures/model-params.edn b/hugr-model/tests/fixtures/model-params.edn index 171860cae..6f8554745 100644 --- a/hugr-model/tests/fixtures/model-params.edn +++ b/hugr-model/tests/fixtures/model-params.edn @@ -6,4 +6,4 @@ (forall ?b type) [?a ?b] [?b ?a] (ext) (dfg [%a %b] [%b %a] - (signature (fn [?a ?b] [?b ?a] (ext))))) + (signature (-> [?a ?b] [?b ?a] (ext))))) diff --git a/hugr-model/tests/snapshots/text__declarative_extensions.snap b/hugr-model/tests/snapshots/text__declarative_extensions.snap index d26909912..852c5bdac 100644 --- a/hugr-model/tests/snapshots/text__declarative_extensions.snap +++ b/hugr-model/tests/snapshots/text__declarative_extensions.snap @@ -13,6 +13,6 @@ expression: "roundtrip(include_str!(\"fixtures/model-decl-exts.edn\"))" (declare-operation array.Init (param ?t type) (param ?n nat) - (fn [?t] [(array.Array ?t ?n)] (ext array)) + (-> [?t] [(array.Array ?t ?n)] (ext array)) (meta docs.description "Initialize an array of size ?n with copies of a default value."))