diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 97c971822..e1674afac 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -3,13 +3,17 @@ //! TODO: YAML declaration and parsing. This should be similar to a plugin //! system (outside the `types` module), which also parses nested [`OpDef`]s. +use itertools::Itertools; pub use semver::Version; +use serde::{Deserialize, Deserializer, Serialize}; use std::collections::btree_map; use std::collections::{BTreeMap, BTreeSet}; -use std::fmt::{Debug, Display, Formatter}; +use std::fmt::Debug; use std::mem; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Weak}; +use derive_more::Display; use thiserror::Error; use crate::hugr::IdentList; @@ -40,41 +44,73 @@ pub use type_def::{TypeDef, TypeDefBound}; pub mod declarative; /// Extension Registries store extensions to be looked up e.g. during validation. -#[derive(Clone, Debug, Default, PartialEq)] -pub struct ExtensionRegistry(BTreeMap>); +#[derive(Debug, Display, Default)] +#[display("ExtensionRegistry[{}]", exts.keys().join(", "))] +pub struct ExtensionRegistry { + /// The extensions in the registry. + exts: BTreeMap>, + /// A flag indicating whether the current set of extensions has been + /// validated. + /// + /// This is used to avoid re-validating the extensions every time the + /// registry is validated, and is set to `false` whenever a new extension is + /// added. + valid: AtomicBool, +} + +impl PartialEq for ExtensionRegistry { + fn eq(&self, other: &Self) -> bool { + self.exts == other.exts + } +} + +impl Clone for ExtensionRegistry { + fn clone(&self) -> Self { + Self { + exts: self.exts.clone(), + valid: self.valid.load(Ordering::Relaxed).into(), + } + } +} impl ExtensionRegistry { + /// Create a new empty extension registry. + pub fn new(extensions: impl IntoIterator>) -> Self { + let mut res = Self::default(); + for ext in extensions.into_iter() { + res.register_updated(ext); + } + res + } + /// Gets the Extension with the given name pub fn get(&self, name: &str) -> Option<&Arc> { - self.0.get(name) + self.exts.get(name) } /// Returns `true` if the registry contains an extension with the given name. pub fn contains(&self, name: &str) -> bool { - self.0.contains_key(name) + self.exts.contains_key(name) } - /// Makes a new [ExtensionRegistry], validating all the extensions in it. - pub fn try_new( - value: impl IntoIterator>, - ) -> Result { - let mut res = ExtensionRegistry(BTreeMap::new()); - - for ext in value.into_iter() { - res.register(ext)?; + /// Validate the set of extensions, ensuring that each extension requirements are also in the registry. + /// + /// Note this potentially asks extensions to validate themselves against other extensions that + /// may *not* be valid themselves yet. It'd be better to order these respecting dependencies, + /// or at least to validate the types first - which we don't do at all yet: + // + // TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be + // cyclically dependent, so there is no perfect solution, and this is at least simple. + pub fn validate(&self) -> Result<(), ExtensionRegistryError> { + if self.valid.load(Ordering::Relaxed) { + return Ok(()); } - - // Note this potentially asks extensions to validate themselves against other extensions that - // may *not* be valid themselves yet. It'd be better to order these respecting dependencies, - // or at least to validate the types first - which we don't do at all yet: - // TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be - // cyclically dependent, so there is no perfect solution, and this is at least simple. - for ext in res.0.values() { - ext.validate(&res) + for ext in self.exts.values() { + ext.validate(self) .map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?; } - - Ok(res) + self.valid.store(true, Ordering::Relaxed); + Ok(()) } /// Registers a new extension to the registry. @@ -85,7 +121,7 @@ impl ExtensionRegistry { extension: impl Into>, ) -> Result<(), ExtensionRegistryError> { let extension = extension.into(); - match self.0.entry(extension.name().clone()) { + match self.exts.entry(extension.name().clone()) { btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered( extension.name().clone(), prev.get().version().clone(), @@ -93,6 +129,9 @@ impl ExtensionRegistry { )), btree_map::Entry::Vacant(ve) => { ve.insert(extension); + // Clear the valid flag so that the registry is re-validated. + self.valid.store(false, Ordering::Relaxed); + Ok(()) } } @@ -109,7 +148,7 @@ impl ExtensionRegistry { /// see [`ExtensionRegistry::register_updated_ref`]. pub fn register_updated(&mut self, extension: impl Into>) { let extension = extension.into(); - match self.0.entry(extension.name().clone()) { + match self.exts.entry(extension.name().clone()) { btree_map::Entry::Occupied(mut prev) => { if prev.get().version() < extension.version() { *prev.get_mut() = extension; @@ -119,6 +158,8 @@ impl ExtensionRegistry { ve.insert(extension); } } + // Clear the valid flag so that the registry is re-validated. + self.valid.store(false, Ordering::Relaxed); } /// Registers a new extension to the registry, keeping the one most up to @@ -131,7 +172,7 @@ impl ExtensionRegistry { /// Clones the Arc only when required. For no-cloning version see /// [`ExtensionRegistry::register_updated`]. pub fn register_updated_ref(&mut self, extension: &Arc) { - match self.0.entry(extension.name().clone()) { + match self.exts.entry(extension.name().clone()) { btree_map::Entry::Occupied(mut prev) => { if prev.get().version() < extension.version() { *prev.get_mut() = extension.clone(); @@ -141,31 +182,36 @@ impl ExtensionRegistry { ve.insert(extension.clone()); } } + // Clear the valid flag so that the registry is re-validated. + self.valid.store(false, Ordering::Relaxed); } /// Returns the number of extensions in the registry. pub fn len(&self) -> usize { - self.0.len() + self.exts.len() } /// Returns `true` if the registry contains no extensions. pub fn is_empty(&self) -> bool { - self.0.is_empty() + self.exts.is_empty() } /// Returns an iterator over the extensions in the registry. pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter { - self.0.values() + self.exts.values() } /// Returns an iterator over the extensions ids in the registry. pub fn ids(&self) -> impl Iterator { - self.0.keys() + self.exts.keys() } /// Delete an extension from the registry and return it if it was present. pub fn remove_extension(&mut self, name: &ExtensionId) -> Option> { - self.0.remove(name) + // Clear the valid flag so that the registry is re-validated. + self.valid.store(false, Ordering::Relaxed); + + self.exts.remove(name) } } @@ -175,7 +221,7 @@ impl IntoIterator for ExtensionRegistry { type IntoIter = std::collections::btree_map::IntoValues>; fn into_iter(self) -> Self::IntoIter { - self.0.into_values() + self.exts.into_values() } } @@ -185,7 +231,7 @@ impl<'a> IntoIterator for &'a ExtensionRegistry { type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc>; fn into_iter(self) -> Self::IntoIter { - self.0.values() + self.exts.values() } } @@ -205,8 +251,33 @@ impl Extend> for ExtensionRegistry { } } +// Encode/decode ExtensionRegistry as a list of extensions. +// We can get the map key from the extension itself. +impl<'de> Deserialize<'de> for ExtensionRegistry { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let extensions: Vec> = Vec::deserialize(deserializer)?; + Ok(ExtensionRegistry::new(extensions)) + } +} + +impl Serialize for ExtensionRegistry { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let extensions: Vec> = self.exts.values().cloned().collect(); + extensions.serialize(serializer) + } +} + /// An Extension Registry containing no extensions. -pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new()); +pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry { + exts: BTreeMap::new(), + valid: AtomicBool::new(true), +}; /// An error that can occur in computing the signature of a node. /// TODO: decide on failure modes @@ -226,7 +297,7 @@ pub enum SignatureError { #[error("Invalid type arguments for operation")] InvalidTypeArgs, /// The Extension Registry did not contain an Extension referenced by the Signature - #[error("Extension '{missing}' not found. Available extensions: {}", + #[error("Extension '{missing}' is not part of the declared HUGR extensions [{}]", available.iter().map(|e| e.to_string()).collect::>().join(", ") )] ExtensionNotFound { @@ -614,7 +685,10 @@ pub enum ExtensionBuildError { } /// A set of extensions identified by their unique [`ExtensionId`]. -#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[derive( + Clone, Debug, Display, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, +)] +#[display("[{}]", _0.iter().join(", "))] pub struct ExtensionSet(BTreeSet); /// A special ExtensionId which indicates that the delta of a non-Function @@ -632,7 +706,7 @@ impl ExtensionSet { } /// Adds a extension to the set. - pub fn insert(&mut self, extension: &ExtensionId) { + pub fn insert(&mut self, extension: ExtensionId) { self.0.insert(extension.clone()); } @@ -660,7 +734,7 @@ impl ExtensionSet { } /// Create a extension set with a single element. - pub fn singleton(extension: &ExtensionId) -> Self { + pub fn singleton(extension: ExtensionId) -> Self { let mut set = Self::new(); set.insert(extension); set @@ -724,7 +798,25 @@ impl ExtensionSet { impl From for ExtensionSet { fn from(id: ExtensionId) -> Self { - Self::singleton(&id) + Self::singleton(id) + } +} + +impl IntoIterator for ExtensionSet { + type Item = ExtensionId; + type IntoIter = std::collections::btree_set::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a> IntoIterator for &'a ExtensionSet { + type Item = &'a ExtensionId; + type IntoIter = std::collections::btree_set::Iter<'a, ExtensionId>; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() } } @@ -738,12 +830,6 @@ fn as_typevar(e: &ExtensionId) -> Option { } } -impl Display for ExtensionSet { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_list().entries(self.0.iter()).finish() - } -} - impl FromIterator for ExtensionSet { fn from_iter>(iter: I) -> Self { Self(BTreeSet::from_iter(iter)) @@ -783,8 +869,8 @@ pub mod test { fn test_register_update() { // Two registers that should remain the same. // We use them to test both `register_updated` and `register_updated_ref`. - let mut reg = ExtensionRegistry::try_new([]).unwrap(); - let mut reg_ref = ExtensionRegistry::try_new([]).unwrap(); + let mut reg = ExtensionRegistry::default(); + let mut reg_ref = ExtensionRegistry::default(); let ext_1_id = ExtensionId::new("ext1").unwrap(); let ext_2_id = ExtensionId::new("ext2").unwrap(); diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index 1f6361b3e..2824aec80 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -127,7 +127,7 @@ impl ExtensionSetDeclaration { registry.register(PRELUDE.clone())?; } if !scope.contains(&PRELUDE_ID) { - scope.insert(&PRELUDE_ID); + scope.insert(PRELUDE_ID); } // Registers extensions sequentially, adding them to the current scope. @@ -137,7 +137,7 @@ impl ExtensionSetDeclaration { registry, }; let ext = decl.make_extension(&self.imports, ctx)?; - scope.insert(ext.name()); + scope.insert(ext.name().clone()); registry.register(ext)?; } diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index d9a3900fa..b060c7ae6 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -245,7 +245,7 @@ impl SignatureFunc { SignatureFunc::MissingValidateFunc(ts) => (ts, args), }; let mut res = pf.instantiate(args, exts)?; - res.extension_reqs.insert(&def.extension); + res.extension_reqs.insert(def.extension.clone()); // If there are any row variables left, this will fail with an error: res.try_into() @@ -658,7 +658,8 @@ pub(super) mod test { Ok(()) })?; - let reg = ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), ext]).unwrap(); + let reg = ExtensionRegistry::new([PRELUDE.clone(), EXTENSION.clone(), ext]); + reg.validate()?; let e = reg.get(&EXT_ID).unwrap(); let list_usize = @@ -822,7 +823,7 @@ pub(super) mod test { )?; // Concrete extension set - let es = ExtensionSet::singleton(&EXT_ID); + let es = ExtensionSet::singleton(EXT_ID); let exp_fun_ty = Signature::new_endo(bool_t()).with_extension_delta(es.clone()); let args = [TypeArg::Extensions { es }]; diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 786b0379e..180f2dfc7 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -128,8 +128,7 @@ lazy_static! { }; /// An extension registry containing only the prelude - pub static ref PRELUDE_REGISTRY: ExtensionRegistry = - ExtensionRegistry::try_new([PRELUDE.clone()]).unwrap(); + pub static ref PRELUDE_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([PRELUDE.clone()]); } pub(crate) fn usize_custom_t(extension_ref: &Weak) -> CustomType { @@ -225,7 +224,7 @@ impl CustomConst for ConstString { } fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(&PRELUDE_ID) + ExtensionSet::singleton(PRELUDE_ID) } fn get_type(&self) -> Type { @@ -418,7 +417,7 @@ impl CustomConst for ConstUsize { } fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(&PRELUDE_ID) + ExtensionSet::singleton(PRELUDE_ID) } fn get_type(&self) -> Type { @@ -464,7 +463,7 @@ impl CustomConst for ConstError { } fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(&PRELUDE_ID) + ExtensionSet::singleton(PRELUDE_ID) } fn get_type(&self) -> Type { error_type() @@ -510,7 +509,7 @@ impl CustomConst for ConstExternalSymbol { } fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(&PRELUDE_ID) + ExtensionSet::singleton(PRELUDE_ID) } fn get_type(&self) -> Type { self.typ.clone() @@ -1022,7 +1021,7 @@ mod test { #[test] fn test_lift() { const XA: ExtensionId = ExtensionId::new_unchecked("xa"); - let op = Lift::new(type_row![Type::UNIT], ExtensionSet::singleton(&XA)); + let op = Lift::new(type_row![Type::UNIT], ExtensionSet::singleton(XA)); let optype: OpType = op.clone().into(); assert_eq!( optype.dataflow_signature().unwrap(), @@ -1102,7 +1101,7 @@ mod test { assert_eq!( error_val.extension_reqs(), - ExtensionSet::singleton(&PRELUDE_ID) + ExtensionSet::singleton(PRELUDE_ID) ); assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); @@ -1171,7 +1170,7 @@ mod test { assert!(string_const.validate().is_ok()); assert_eq!( string_const.extension_reqs(), - ExtensionSet::singleton(&PRELUDE_ID) + ExtensionSet::singleton(PRELUDE_ID) ); assert!(string_const.equal_consts(&ConstString::new("Lorem ipsum".into()))); assert!(!string_const.equal_consts(&ConstString::new("Lorem ispum".into()))); @@ -1198,7 +1197,7 @@ mod test { assert!(subject.validate().is_ok()); assert_eq!( subject.extension_reqs(), - ExtensionSet::singleton(&PRELUDE_ID) + ExtensionSet::singleton(PRELUDE_ID) ); assert!(subject.equal_consts(&ConstExternalSymbol::new("foo", Type::UNIT, false))); assert!(!subject.equal_consts(&ConstExternalSymbol::new("bar", Type::UNIT, false))); diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index 41050f808..4ae3023a8 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -871,7 +871,7 @@ mod tests { #[test] fn test_repeat_def() { - let op = ArrayRepeat::new(qb_t(), 2, ExtensionSet::singleton(&PRELUDE_ID)); + let op = ArrayRepeat::new(qb_t(), 2, ExtensionSet::singleton(PRELUDE_ID)); let optype: OpType = op.clone().into(); let new_op: ArrayRepeat = optype.cast().unwrap(); assert_eq!(new_op, op); @@ -881,7 +881,7 @@ mod tests { fn test_repeat() { let size = 2; let element_ty = qb_t(); - let es = ExtensionSet::singleton(&PRELUDE_ID); + let es = ExtensionSet::singleton(PRELUDE_ID); let op = ArrayRepeat::new(element_ty.clone(), size, es.clone()); let optype: OpType = op.into(); @@ -907,7 +907,7 @@ mod tests { qb_t(), vec![usize_t()], 2, - ExtensionSet::singleton(&PRELUDE_ID), + ExtensionSet::singleton(PRELUDE_ID), ); let optype: OpType = op.clone().into(); let new_op: ArrayScan = optype.cast().unwrap(); @@ -919,7 +919,7 @@ mod tests { let size = 2; let src_ty = qb_t(); let tgt_ty = bool_t(); - let es = ExtensionSet::singleton(&PRELUDE_ID); + let es = ExtensionSet::singleton(PRELUDE_ID); let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size, es.clone()); let optype: OpType = op.into(); @@ -947,7 +947,7 @@ mod tests { let tgt_ty = bool_t(); let acc_ty1 = usize_t(); let acc_ty2 = qb_t(); - let es = ExtensionSet::singleton(&PRELUDE_ID); + let es = ExtensionSet::singleton(PRELUDE_ID); let op = ArrayScan::new( src_ty.clone(), diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 93e1b97f9..3ccaeb857 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -1,5 +1,8 @@ //! Utilities for resolving operations and types present in a HUGR, and updating -//! the list of used extensions. See [`crate::Hugr::resolve_extension_defs`]. +//! the list of used extensions. The functionalities of this module can be +//! called from the type methods [`crate::Hugr::resolve_extension_defs`], +//! [`crate::ops::OpType::used_extensions`], and +//! [`crate::types::Signature::used_extensions`]. //! //! When listing "used extensions" we only care about _definitional_ extension //! requirements, i.e., the operations and types that are required to define the @@ -18,16 +21,18 @@ mod ops; mod types; +mod types_mut; -pub(crate) use ops::update_op_extensions; -pub(crate) use types::update_op_types_extensions; +pub(crate) use ops::{collect_op_extension, resolve_op_extensions}; +pub(crate) use types::{collect_op_types_extensions, collect_signature_exts}; +pub(crate) use types_mut::resolve_op_types_extensions; use derive_more::{Display, Error, From}; use super::{Extension, ExtensionId, ExtensionRegistry}; use crate::ops::custom::OpaqueOpError; use crate::ops::{NamedOp, OpName, OpType}; -use crate::types::TypeName; +use crate::types::{FuncTypeBase, MaybeRV, TypeName}; use crate::Node; /// Errors that can occur during extension resolution. @@ -101,3 +106,63 @@ impl ExtensionResolutionError { } } } + +/// Errors that can occur when collecting extension requirements. +#[derive(Debug, Display, Clone, Error, From, PartialEq)] +#[non_exhaustive] +pub enum ExtensionCollectionError { + /// An operation requires an extension that is not in the given registry. + #[display( + "{op}{} contains custom types for which have lost the reference to their defining extensions. Dropped extensions: {}", + if let Some(node) = node { format!(" ({})", node) } else { "".to_string() }, + missing_extensions.join(", ") + )] + DroppedOpExtensions { + /// The node that is missing extensions. + node: Option, + /// The operation that is missing extensions. + op: OpName, + /// The missing extensions. + missing_extensions: Vec, + }, + /// A signature requires an extension that is not in the given registry. + #[display( + "Signature {signature} contains custom types for which have lost the reference to their defining extensions. Dropped extensions: {}", + missing_extensions.join(", ") + )] + DroppedSignatureExtensions { + /// The signature that is missing extensions. + signature: String, + /// The missing extensions. + missing_extensions: Vec, + }, +} + +impl ExtensionCollectionError { + /// Create a new error when operation extensions have been dropped. + pub fn dropped_op_extension( + node: Option, + op: &OpType, + missing_extension: impl IntoIterator, + ) -> Self { + Self::DroppedOpExtensions { + node, + op: NamedOp::name(op), + missing_extensions: missing_extension.into_iter().collect(), + } + } + + /// Create a new error when signature extensions have been dropped. + pub fn dropped_signature( + signature: &FuncTypeBase, + missing_extension: impl IntoIterator, + ) -> Self { + Self::DroppedSignatureExtensions { + signature: format!("{signature}"), + missing_extensions: missing_extension.into_iter().collect(), + } + } +} + +#[cfg(test)] +mod test; diff --git a/hugr-core/src/extension/resolution/ops.rs b/hugr-core/src/extension/resolution/ops.rs index 7c8fbfc37..42e3954dd 100644 --- a/hugr-core/src/extension/resolution/ops.rs +++ b/hugr-core/src/extension/resolution/ops.rs @@ -1,12 +1,50 @@ -//! Resolve `OpaqueOp`s into `ExtensionOp`s and return an operation's required extension. +//! Resolve `OpaqueOp`s into `ExtensionOp`s and return an operation's required +//! extension. +//! +//! Contains both mutable ([`resolve_op_extensions`]) and immutable +//! ([`collect_operation_extension`]) methods to resolve operations and collect +//! the required extensions respectively. use std::sync::Arc; -use super::{Extension, ExtensionRegistry, ExtensionResolutionError}; +use super::{Extension, ExtensionCollectionError, ExtensionRegistry, ExtensionResolutionError}; use crate::ops::custom::OpaqueOpError; use crate::ops::{DataflowOpTrait, ExtensionOp, NamedOp, OpType}; use crate::Node; +/// Returns the extension in the registry required by the operation. +/// +/// If the operation does not require an extension, returns `None`. +/// +/// [`ExtensionOp`]s store a [`Weak`] reference to their extension, which can be +/// invalidated if the original `Arc` is dropped. On such cases, we +/// return an error with the missing extension names. +/// +/// # Parameters +/// +/// - `node`: The node where the operation is located, if available. This is +/// used to provide context in the error message. +/// - `op`: The operation to collect the extensions from. +pub(crate) fn collect_op_extension( + node: Option, + op: &OpType, +) -> Result>, ExtensionCollectionError> { + let OpType::ExtensionOp(ext_op) = op else { + // TODO: Extract the extension when the operation is a `Const`. + // https://github.com/CQCL/hugr/issues/1742 + return Ok(None); + }; + let ext = ext_op.def().extension(); + match ext.upgrade() { + Some(e) => Ok(Some(e)), + None => Err(ExtensionCollectionError::dropped_op_extension( + node, + op, + [ext_op.def().extension_id().clone()], + )), + } +} + /// Compute the required extension for an operation. /// /// If the op is a [`OpType::OpaqueOp`], replace it with a resolved @@ -20,7 +58,7 @@ use crate::Node; /// /// If the serialized opaque resolves to a definition that conflicts with what /// was serialized. Or if the operation is not found in the registry. -pub(crate) fn update_op_extensions<'e>( +pub(crate) fn resolve_op_extensions<'e>( node: Node, op: &mut OpType, extensions: &'e ExtensionRegistry, diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs new file mode 100644 index 000000000..7d02d1e0c --- /dev/null +++ b/hugr-core/src/extension/resolution/test.rs @@ -0,0 +1,290 @@ +//! Tests for extension resolution. + +use core::panic; +use std::sync::Arc; + +use cool_asserts::assert_matches; +use itertools::Itertools; +use rstest::rstest; + +use crate::builder::{ + Container, Dataflow, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, +}; +use crate::extension::prelude::{bool_t, ConstUsize}; +use crate::extension::resolution::{ + resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError, +}; +use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet, PRELUDE}; +use crate::ops::{CallIndirect, ExtensionOp, Input, OpTrait, OpType, Tag, Value}; +use crate::std_extensions::arithmetic::float_types::{self, float64_type}; +use crate::std_extensions::arithmetic::int_ops; +use crate::std_extensions::arithmetic::int_types::{self, int_type}; +use crate::types::{Signature, Type}; +use crate::{type_row, Extension, Hugr, HugrView}; + +#[rstest] +#[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())] +// A type with extra extensions in its instantiated type arguments. +#[case::parametric_op(int_ops::IntOpDef::ieq.with_log_width(4), + ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned()] +))] +fn collect_type_extensions(#[case] op: impl Into, #[case] extensions: ExtensionRegistry) { + let op = op.into(); + let resolved = op.used_extensions().unwrap(); + assert_eq!(resolved, extensions); +} + +#[rstest] +#[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())] +// A type with extra extensions in its instantiated type arguments. +#[case::parametric_op(int_ops::IntOpDef::ieq.with_log_width(4), + ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned()] +))] +fn resolve_type_extensions(#[case] op: impl Into, #[case] extensions: ExtensionRegistry) { + let op = op.into(); + + // Ensure that all the `Weak` pointers get invalidated by round-tripping via serialization. + let ser = serde_json::to_string(&op).unwrap(); + let mut deser_op: OpType = serde_json::from_str(&ser).unwrap(); + + let dummy_node = portgraph::NodeIndex::new(0).into(); + + let mut used_exts = ExtensionRegistry::default(); + resolve_op_extensions(dummy_node, &mut deser_op, &extensions).unwrap(); + resolve_op_types_extensions(dummy_node, &mut deser_op, &extensions, &mut used_exts).unwrap(); + + let deser_extensions = deser_op.used_extensions().unwrap(); + + assert_eq!( + deser_extensions, extensions, + "{deser_extensions} != {extensions}" + ); +} + +/// Create a new test extension with a single operation. +/// +/// Returns an instance of the defined op. +fn make_extension(name: &str, op_name: &str) -> (Arc, OpType) { + let ext = Extension::new_test_arc(ExtensionId::new_unchecked(name), |ext, extension_ref| { + ext.add_op( + op_name.into(), + "".to_string(), + Signature::new_endo(vec![bool_t()]), + extension_ref, + ) + .unwrap(); + }); + let op_def = ext.get_op(op_name).unwrap(); + let op = ExtensionOp::new(op_def.clone(), vec![], &ExtensionRegistry::default()).unwrap(); + (ext, op.into()) +} + +/// Build a hugr with all possible op nodes and resolve the extensions. +#[rstest] +fn resolve_hugr_extensions() { + let (ext_a, op_a) = make_extension("dummy.a", "op_a"); + let (ext_b, op_b) = make_extension("dummy.b", "op_b"); + let (ext_c, op_c) = make_extension("dummy.c", "op_c"); + let (ext_d, op_d) = make_extension("dummy.d", "op_d"); + let (ext_e, op_e) = make_extension("dummy.e", "op_e"); + + let build_extensions = ExtensionRegistry::new([ + PRELUDE.to_owned(), + ext_a.clone(), + ext_b.clone(), + ext_c.clone(), + ext_d.clone(), + ext_e.clone(), + float_types::EXTENSION.to_owned(), + int_types::EXTENSION.to_owned(), + ]); + + let mut module = ModuleBuilder::new(); + + // A constant op using the prelude extension. + module.add_constant(Value::extension(ConstUsize::new(42))); + + // A function declaration using the floats extension in its signature. + let decl = module + .declare( + "dummy_declaration", + Signature::new_endo(vec![float64_type()]).into(), + ) + .unwrap(); + + // A function definition using the int_types and float_types extension in its body. + let mut func = module + .define_function( + "dummy_fn", + Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( + [ + ext_a.name(), + ext_b.name(), + ext_c.name(), + ext_d.name(), + ext_e.name(), + ] + .into_iter() + .cloned() + .collect::(), + ), + ) + .unwrap(); + let [func_i0, func_i1] = func.input_wires_arr(); + + // Call the function declaration directly, and load & call indirectly. + func.call( + &decl, + &[], + vec![func_i0], + &ExtensionRegistry::new([float_types::EXTENSION.to_owned()]), + ) + .unwrap(); + let loaded_func = func + .load_func( + &decl, + &[], + &ExtensionRegistry::new([float_types::EXTENSION.to_owned()]), + ) + .unwrap(); + func.add_dataflow_op( + CallIndirect { + signature: Signature::new_endo(vec![float64_type()]), + }, + vec![loaded_func, func_i0], + ) + .unwrap(); + + // Add one of the custom ops. + func.add_dataflow_op(op_a, vec![func_i1]).unwrap(); + + // A nested dataflow region. + let mut dfg = func.dfg_builder_endo([(bool_t(), func_i1)]).unwrap(); + let dfg_inputs = dfg.input_wires().collect_vec(); + dfg.add_dataflow_op(op_b, dfg_inputs.clone()).unwrap(); + dfg.finish_with_outputs(dfg_inputs).unwrap(); + + // A tag + func.add_dataflow_op( + Tag::new(0, vec![vec![bool_t()].into(), vec![int_type(4)].into()]), + vec![func_i1], + ) + .unwrap(); + + // Dfg control flow: Tail loop + let mut tail_loop = func + .tail_loop_builder([(bool_t(), func_i1)], [], vec![].into()) + .unwrap(); + let tl_inputs = tail_loop.input_wires().collect_vec(); + tail_loop.add_dataflow_op(op_c, tl_inputs).unwrap(); + let tl_tag = tail_loop.add_load_const(Value::true_val()); + let tl_tag = tail_loop + .add_dataflow_op( + Tag::new(0, vec![vec![Type::new_unit_sum(2)].into(), vec![].into()]), + vec![tl_tag], + ) + .unwrap() + .out_wire(0); + tail_loop.finish_with_outputs(tl_tag, vec![]).unwrap(); + + // Dfg control flow: Conditionals + let cond_tag = func.add_load_const(Value::unary_unit_sum()); + let mut cond = func + .conditional_builder(([type_row![]], cond_tag), [], type_row![]) + .unwrap(); + let mut case = cond.case_builder(0).unwrap(); + case.add_dataflow_op(op_e, [func_i1]).unwrap(); + case.finish_with_outputs([]).unwrap(); + + // Cfg control flow. + let mut cfg = func + .cfg_builder([(bool_t(), func_i1)], vec![].into()) + .unwrap(); + let mut cfg_entry = cfg.entry_builder([type_row![]], type_row![]).unwrap(); + let [cfg_i0] = cfg_entry.input_wires_arr(); + cfg_entry.add_dataflow_op(op_d, [cfg_i0]).unwrap(); + let cfg_tag = cfg_entry.add_load_const(Value::unary_unit_sum()); + let cfg_entry_wire = cfg_entry.finish_with_outputs(cfg_tag, []).unwrap(); + let cfg_exit = cfg.exit_block(); + cfg.branch(&cfg_entry_wire, 0, &cfg_exit).unwrap(); + + // -------------------------------------------------- + + // Finally, finish the hugr and ensure it's using the right extensions. + func.finish_with_outputs(vec![]).unwrap(); + let mut hugr = module + .finish_hugr(&build_extensions) + .unwrap_or_else(|e| panic!("{e}")); + + // Check that the read-only methods collect the same extensions. + let mut collected_exts = ExtensionRegistry::default(); + for node in hugr.nodes() { + let op = hugr.get_optype(node); + collected_exts.extend(op.used_extensions().unwrap()); + } + assert_eq!( + collected_exts, build_extensions, + "{collected_exts} != {build_extensions}" + ); + + // Check that the mutable methods collect the same extensions. + assert_matches!( + hugr.resolve_extension_defs(&ExtensionRegistry::default()), + Err(_) + ); + let resolved = hugr.resolve_extension_defs(&build_extensions).unwrap(); + assert_eq!( + &resolved, &build_extensions, + "{resolved} != {build_extensions}" + ); +} + +/// Fail when collecting extensions but the weak pointers are not resolved. +#[rstest] +fn dropped_weak_extensions() { + let (ext_a, op_a) = make_extension("dummy.a", "op_a"); + let build_extensions = ExtensionRegistry::new([ + PRELUDE.to_owned(), + ext_a.clone(), + float_types::EXTENSION.to_owned(), + ]); + + let mut func = FunctionBuilder::new( + "dummy_fn", + Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( + [ext_a.name()] + .into_iter() + .cloned() + .collect::(), + ), + ) + .unwrap(); + let [_func_i0, func_i1] = func.input_wires_arr(); + func.add_dataflow_op(op_a, vec![func_i1]).unwrap(); + + let hugr = func.finish_hugr(&build_extensions).unwrap(); + + // Do a serialization roundtrip to drop the references. + let ser = serde_json::to_string(&hugr).unwrap(); + let hugr: Hugr = serde_json::from_str(&ser).unwrap(); + + let op_collection = hugr + .nodes() + .try_for_each(|node| hugr.get_optype(node).used_extensions().map(|_| ())); + assert_matches!( + op_collection, + Err(ExtensionCollectionError::DroppedOpExtensions { .. }) + ); + + let op_collection = hugr.nodes().try_for_each(|node| { + let op = hugr.get_optype(node); + if let Some(sig) = op.dataflow_signature() { + sig.used_extensions()?; + } + Ok(()) + }); + assert_matches!( + op_collection, + Err(ExtensionCollectionError::DroppedSignatureExtensions { .. }) + ); +} diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index b249a4dca..128900e25 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -1,174 +1,223 @@ -//! Resolve weak links inside `CustomType`s in an optype's signature. +//! Collect the extensions referenced inside `CustomType`s in an optype or +//! signature. +//! +//! Fails if any of the weak extension pointers have been invalidated. +//! +//! See [`super::resolve_op_types_extensions`] for a mutating version that +//! updates the weak links to point to the correct extensions. -use std::sync::Arc; - -use super::{ExtensionRegistry, ExtensionResolutionError}; -use crate::ops::OpType; +use super::ExtensionCollectionError; +use crate::extension::{ExtensionRegistry, ExtensionSet}; +use crate::ops::{DataflowOpTrait, OpType}; use crate::types::type_row::TypeRowBase; -use crate::types::{MaybeRV, Signature, SumType, TypeBase, TypeEnum}; +use crate::types::{FuncTypeBase, MaybeRV, SumType, TypeArg, TypeBase, TypeEnum}; use crate::Node; -/// Replace the dangling extension pointer in the [`CustomType`]s inside a -/// signature with a valid pointer to the extension in the `extensions` -/// registry. +/// Collects every extension used te define the types in an operation. /// -/// When a pointer is replaced, the extension is added to the -/// `used_extensions` registry and the new type definition is returned. +/// Custom types store a [`Weak`] reference to their extension, which can be +/// invalidated if the original `Arc` is dropped. This normally +/// happens when deserializing a HUGR. On such cases, we return an error with +/// the missing extension names. /// -/// This is a helper function used right after deserializing a Hugr. -pub fn update_op_types_extensions( - node: Node, - op: &mut OpType, - extensions: &ExtensionRegistry, - used_extensions: &mut ExtensionRegistry, -) -> Result<(), ExtensionResolutionError> { +/// Use [`super::resolve_op_types_extensions`] instead to update the weak references and +/// ensure they point to valid extensions. +/// +/// # Attributes +/// +/// - `node`: The node where the operation is located, if available. +/// This is used to provide context in the error message. +/// - `op`: The operation to collect the extensions from. +pub(crate) fn collect_op_types_extensions( + node: Option, + op: &OpType, +) -> Result { + let mut used = ExtensionRegistry::default(); + let mut missing = ExtensionSet::new(); + match op { OpType::ExtensionOp(ext) => { - update_signature_exts(node, ext.signature_mut(), extensions, used_extensions)? - } - OpType::FuncDefn(f) => { - update_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)? - } - OpType::FuncDecl(f) => { - update_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)? - } - OpType::Const(_c) => { - // TODO: Is it OK to assume that `Value::get_type` returns a well-resolved value? - } - OpType::Input(inp) => { - update_type_row_exts(node, &mut inp.types, extensions, used_extensions)? + for arg in ext.args() { + collect_typearg_exts(arg, &mut used, &mut missing); + } + collect_signature_exts(&ext.signature(), &mut used, &mut missing) } - OpType::Output(out) => { - update_type_row_exts(node, &mut out.types, extensions, used_extensions)? + OpType::FuncDefn(f) => collect_signature_exts(f.signature.body(), &mut used, &mut missing), + OpType::FuncDecl(f) => collect_signature_exts(f.signature.body(), &mut used, &mut missing), + OpType::Const(c) => { + let typ = c.get_type(); + collect_type_exts(&typ, &mut used, &mut missing); } + OpType::Input(inp) => collect_type_row_exts(&inp.types, &mut used, &mut missing), + OpType::Output(out) => collect_type_row_exts(&out.types, &mut used, &mut missing), OpType::Call(c) => { - update_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?; - update_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?; - } - OpType::CallIndirect(c) => { - update_signature_exts(node, &mut c.signature, extensions, used_extensions)? - } - OpType::LoadConstant(lc) => { - update_type_exts(node, &mut lc.datatype, extensions, used_extensions)? + collect_signature_exts(c.func_sig.body(), &mut used, &mut missing); + collect_signature_exts(&c.instantiation, &mut used, &mut missing); } + OpType::CallIndirect(c) => collect_signature_exts(&c.signature, &mut used, &mut missing), + OpType::LoadConstant(lc) => collect_type_exts(&lc.datatype, &mut used, &mut missing), OpType::LoadFunction(lf) => { - update_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?; - update_signature_exts(node, &mut lf.signature, extensions, used_extensions)?; - } - OpType::DFG(dfg) => { - update_signature_exts(node, &mut dfg.signature, extensions, used_extensions)? + collect_signature_exts(lf.func_sig.body(), &mut used, &mut missing); + collect_signature_exts(&lf.signature, &mut used, &mut missing); } + OpType::DFG(dfg) => collect_signature_exts(&dfg.signature, &mut used, &mut missing), OpType::OpaqueOp(op) => { - update_signature_exts(node, op.signature_mut(), extensions, used_extensions)? + for arg in op.args() { + collect_typearg_exts(arg, &mut used, &mut missing); + } + collect_signature_exts(&op.signature(), &mut used, &mut missing) } OpType::Tag(t) => { - for variant in t.variants.iter_mut() { - update_type_row_exts(node, variant, extensions, used_extensions)? + for variant in t.variants.iter() { + collect_type_row_exts(variant, &mut used, &mut missing) } } OpType::DataflowBlock(db) => { - update_type_row_exts(node, &mut db.inputs, extensions, used_extensions)?; - update_type_row_exts(node, &mut db.other_outputs, extensions, used_extensions)?; - for row in db.sum_rows.iter_mut() { - update_type_row_exts(node, row, extensions, used_extensions)?; + collect_type_row_exts(&db.inputs, &mut used, &mut missing); + collect_type_row_exts(&db.other_outputs, &mut used, &mut missing); + for row in db.sum_rows.iter() { + collect_type_row_exts(row, &mut used, &mut missing); } } OpType::ExitBlock(e) => { - update_type_row_exts(node, &mut e.cfg_outputs, extensions, used_extensions)?; + collect_type_row_exts(&e.cfg_outputs, &mut used, &mut missing); } OpType::TailLoop(tl) => { - update_type_row_exts(node, &mut tl.just_inputs, extensions, used_extensions)?; - update_type_row_exts(node, &mut tl.just_outputs, extensions, used_extensions)?; - update_type_row_exts(node, &mut tl.rest, extensions, used_extensions)?; + collect_type_row_exts(&tl.just_inputs, &mut used, &mut missing); + collect_type_row_exts(&tl.just_outputs, &mut used, &mut missing); + collect_type_row_exts(&tl.rest, &mut used, &mut missing); } OpType::CFG(cfg) => { - update_signature_exts(node, &mut cfg.signature, extensions, used_extensions)?; + collect_signature_exts(&cfg.signature, &mut used, &mut missing); } OpType::Conditional(cond) => { - for row in cond.sum_rows.iter_mut() { - update_type_row_exts(node, row, extensions, used_extensions)?; + for row in cond.sum_rows.iter() { + collect_type_row_exts(row, &mut used, &mut missing); } - update_type_row_exts(node, &mut cond.other_inputs, extensions, used_extensions)?; - update_type_row_exts(node, &mut cond.outputs, extensions, used_extensions)?; + collect_type_row_exts(&cond.other_inputs, &mut used, &mut missing); + collect_type_row_exts(&cond.outputs, &mut used, &mut missing); } OpType::Case(case) => { - update_signature_exts(node, &mut case.signature, extensions, used_extensions)?; + collect_signature_exts(&case.signature, &mut used, &mut missing); } // Ignore optypes that do not store a signature. - _ => {} - } - Ok(()) + OpType::Module(_) | OpType::AliasDecl(_) | OpType::AliasDefn(_) => {} + }; + + missing + .is_empty() + .then_some(used) + .ok_or(ExtensionCollectionError::dropped_op_extension( + node, op, missing, + )) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a signature. +/// Collect the Extension pointers in the [`CustomType`]s inside a signature. +/// +/// # Attributes /// -/// Adds the extensions used in the signature to the `used_extensions` registry. -fn update_signature_exts( - node: Node, - signature: &mut Signature, - extensions: &ExtensionRegistry, +/// - `signature`: The signature to collect the extensions from. +/// - `used_extensions`: A The registry where to store the used extensions. +/// - `missing_extensions`: A set of `ExtensionId`s of which the +/// `Weak` pointer has been invalidated. +pub(crate) fn collect_signature_exts( + signature: &FuncTypeBase, used_extensions: &mut ExtensionRegistry, -) -> Result<(), ExtensionResolutionError> { + missing_extensions: &mut ExtensionSet, +) { // Note that we do not include the signature's `extension_reqs` here, as those refer - // to _runtime_ requirements that may not be currently present. + // to _runtime_ requirements that we do not be require to be defined. + // // See https://github.com/CQCL/hugr/issues/1734 // TODO: Update comment once that issue gets implemented. - update_type_row_exts(node, &mut signature.input, extensions, used_extensions)?; - update_type_row_exts(node, &mut signature.output, extensions, used_extensions)?; - Ok(()) + collect_type_row_exts(&signature.input, used_extensions, missing_extensions); + collect_type_row_exts(&signature.output, used_extensions, missing_extensions); } -/// Update all weak Extension pointers in the [`CustomType`]s inside a type row. +/// Collect the Extension pointers in the [`CustomType`]s inside a type row. /// -/// Adds the extensions used in the row to the `used_extensions` registry. -fn update_type_row_exts( - node: Node, - row: &mut TypeRowBase, - extensions: &ExtensionRegistry, +/// # Attributes +/// +/// - `row`: The type row to collect the extensions from. +/// - `used_extensions`: A The registry where to store the used extensions. +/// - `missing_extensions`: A set of `ExtensionId`s of which the +/// `Weak` pointer has been invalidated. +fn collect_type_row_exts( + row: &TypeRowBase, used_extensions: &mut ExtensionRegistry, -) -> Result<(), ExtensionResolutionError> { - for ty in row.iter_mut() { - update_type_exts(node, ty, extensions, used_extensions)?; + missing_extensions: &mut ExtensionSet, +) { + for ty in row.iter() { + collect_type_exts(ty, used_extensions, missing_extensions); } - Ok(()) } -/// Update all weak Extension pointers in the [`CustomType`]s inside a type. +/// Collect the Extension pointers in the [`CustomType`]s inside a type. +/// +/// # Attributes /// -/// Adds the extensions used in the type to the `used_extensions` registry. -fn update_type_exts( - node: Node, - typ: &mut TypeBase, - extensions: &ExtensionRegistry, +/// - `typ`: The type to collect the extensions from. +/// - `used_extensions`: A The registry where to store the used extensions. +/// - `missing_extensions`: A set of `ExtensionId`s of which the +/// `Weak` pointer has been invalidated. +pub(super) fn collect_type_exts( + typ: &TypeBase, used_extensions: &mut ExtensionRegistry, -) -> Result<(), ExtensionResolutionError> { - match typ.as_type_enum_mut() { + missing_extensions: &mut ExtensionSet, +) { + match typ.as_type_enum() { TypeEnum::Extension(custom) => { - let ext_id = custom.extension(); - let ext = extensions.get(ext_id).ok_or_else(|| { - ExtensionResolutionError::missing_type_extension( - node, - custom.name(), - ext_id, - extensions, - ) - })?; - - // Add the extension to the used extensions registry, - // and update the CustomType with the valid pointer. - used_extensions.register_updated_ref(ext); - custom.update_extension(Arc::downgrade(ext)); + for arg in custom.args() { + collect_typearg_exts(arg, used_extensions, missing_extensions); + } + match custom.extension_ref().upgrade() { + Some(ext) => { + used_extensions.register_updated(ext); + } + None => { + missing_extensions.insert(custom.extension().clone()); + } + } } TypeEnum::Function(f) => { - update_type_row_exts(node, &mut f.input, extensions, used_extensions)?; - update_type_row_exts(node, &mut f.output, extensions, used_extensions)?; + collect_type_row_exts(&f.input, used_extensions, missing_extensions); + collect_type_row_exts(&f.output, used_extensions, missing_extensions); } TypeEnum::Sum(SumType::General { rows }) => { - for row in rows.iter_mut() { - update_type_row_exts(node, row, extensions, used_extensions)?; + for row in rows.iter() { + collect_type_row_exts(row, used_extensions, missing_extensions); + } + } + // Other types do not store extensions. + TypeEnum::Alias(_) + | TypeEnum::RowVar(_) + | TypeEnum::Variable(_, _) + | TypeEnum::Sum(SumType::Unit { .. }) => {} + } +} + +/// Collect the Extension pointers in the [`CustomType`]s inside a type argument. +/// +/// # Attributes +/// +/// - `arg`: The type argument to collect the extensions from. +/// - `used_extensions`: A The registry where to store the used extensions. +/// - `missing_extensions`: A set of `ExtensionId`s of which the +/// `Weak` pointer has been invalidated. +fn collect_typearg_exts( + arg: &TypeArg, + used_extensions: &mut ExtensionRegistry, + missing_extensions: &mut ExtensionSet, +) { + match arg { + TypeArg::Type { ty } => collect_type_exts(ty, used_extensions, missing_extensions), + TypeArg::Sequence { elems } => { + for elem in elems.iter() { + collect_typearg_exts(elem, used_extensions, missing_extensions); } } + // We ignore the `TypeArg::Extension` case, as it is not required to + // **define** the hugr. _ => {} } - Ok(()) } diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs new file mode 100644 index 000000000..7f8f16f4c --- /dev/null +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -0,0 +1,222 @@ +//! Resolve weak links inside `CustomType`s in an optype's signature, while +//! collecting all used extensions. +//! +//! For a non-mutating option see [`super::collect_op_types_extensions`]. + +use std::sync::Arc; + +use super::types::collect_type_exts; +use super::{ExtensionRegistry, ExtensionResolutionError}; +use crate::extension::ExtensionSet; +use crate::ops::OpType; +use crate::types::type_row::TypeRowBase; +use crate::types::{MaybeRV, Signature, SumType, TypeArg, TypeBase, TypeEnum}; +use crate::Node; + +/// Replace the dangling extension pointer in the [`CustomType`]s inside an +/// optype with a valid pointer to the extension in the `extensions` +/// registry. +/// +/// When a pointer is replaced, the extension is added to the +/// `used_extensions` registry. +/// +/// This is a helper function used right after deserializing a Hugr. +pub fn resolve_op_types_extensions( + node: Node, + op: &mut OpType, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + match op { + OpType::ExtensionOp(ext) => { + for arg in ext.args_mut() { + resolve_typearg_exts(node, arg, extensions, used_extensions)?; + } + resolve_signature_exts(node, ext.signature_mut(), extensions, used_extensions)? + } + OpType::FuncDefn(f) => { + resolve_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)? + } + OpType::FuncDecl(f) => { + resolve_signature_exts(node, f.signature.body_mut(), extensions, used_extensions)? + } + OpType::Const(c) => { + let typ = c.get_type(); + let mut missing = ExtensionSet::new(); + collect_type_exts(&typ, used_extensions, &mut missing); + // We expect that the `CustomConst::get_type` binary calls always return valid extensions. + // As we cannot update the `CustomConst` type, we ignore the result. + // + // Some exotic consts may need https://github.com/CQCL/hugr/issues/1742 to be implemented + // to pass this test. + //assert!(missing.is_empty()); + } + OpType::Input(inp) => { + resolve_type_row_exts(node, &mut inp.types, extensions, used_extensions)? + } + OpType::Output(out) => { + resolve_type_row_exts(node, &mut out.types, extensions, used_extensions)? + } + OpType::Call(c) => { + resolve_signature_exts(node, c.func_sig.body_mut(), extensions, used_extensions)?; + resolve_signature_exts(node, &mut c.instantiation, extensions, used_extensions)?; + } + OpType::CallIndirect(c) => { + resolve_signature_exts(node, &mut c.signature, extensions, used_extensions)? + } + OpType::LoadConstant(lc) => { + resolve_type_exts(node, &mut lc.datatype, extensions, used_extensions)? + } + OpType::LoadFunction(lf) => { + resolve_signature_exts(node, lf.func_sig.body_mut(), extensions, used_extensions)?; + resolve_signature_exts(node, &mut lf.signature, extensions, used_extensions)?; + } + OpType::DFG(dfg) => { + resolve_signature_exts(node, &mut dfg.signature, extensions, used_extensions)? + } + OpType::OpaqueOp(op) => { + for arg in op.args_mut() { + resolve_typearg_exts(node, arg, extensions, used_extensions)?; + } + resolve_signature_exts(node, op.signature_mut(), extensions, used_extensions)? + } + OpType::Tag(t) => { + for variant in t.variants.iter_mut() { + resolve_type_row_exts(node, variant, extensions, used_extensions)? + } + } + OpType::DataflowBlock(db) => { + resolve_type_row_exts(node, &mut db.inputs, extensions, used_extensions)?; + resolve_type_row_exts(node, &mut db.other_outputs, extensions, used_extensions)?; + for row in db.sum_rows.iter_mut() { + resolve_type_row_exts(node, row, extensions, used_extensions)?; + } + } + OpType::ExitBlock(e) => { + resolve_type_row_exts(node, &mut e.cfg_outputs, extensions, used_extensions)?; + } + OpType::TailLoop(tl) => { + resolve_type_row_exts(node, &mut tl.just_inputs, extensions, used_extensions)?; + resolve_type_row_exts(node, &mut tl.just_outputs, extensions, used_extensions)?; + resolve_type_row_exts(node, &mut tl.rest, extensions, used_extensions)?; + } + OpType::CFG(cfg) => { + resolve_signature_exts(node, &mut cfg.signature, extensions, used_extensions)?; + } + OpType::Conditional(cond) => { + for row in cond.sum_rows.iter_mut() { + resolve_type_row_exts(node, row, extensions, used_extensions)?; + } + resolve_type_row_exts(node, &mut cond.other_inputs, extensions, used_extensions)?; + resolve_type_row_exts(node, &mut cond.outputs, extensions, used_extensions)?; + } + OpType::Case(case) => { + resolve_signature_exts(node, &mut case.signature, extensions, used_extensions)?; + } + // Ignore optypes that do not store a signature. + OpType::Module(_) | OpType::AliasDecl(_) | OpType::AliasDefn(_) => {} + } + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a signature. +/// +/// Adds the extensions used in the signature to the `used_extensions` registry. +fn resolve_signature_exts( + node: Node, + signature: &mut Signature, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + // Note that we do not include the signature's `extension_reqs` here, as those refer + // to _runtime_ requirements that may not be currently present. + // See https://github.com/CQCL/hugr/issues/1734 + // TODO: Update comment once that issue gets implemented. + resolve_type_row_exts(node, &mut signature.input, extensions, used_extensions)?; + resolve_type_row_exts(node, &mut signature.output, extensions, used_extensions)?; + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a type row. +/// +/// Adds the extensions used in the row to the `used_extensions` registry. +fn resolve_type_row_exts( + node: Node, + row: &mut TypeRowBase, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + for ty in row.iter_mut() { + resolve_type_exts(node, ty, extensions, used_extensions)?; + } + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a type. +/// +/// Adds the extensions used in the type to the `used_extensions` registry. +fn resolve_type_exts( + node: Node, + typ: &mut TypeBase, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + match typ.as_type_enum_mut() { + TypeEnum::Extension(custom) => { + for arg in custom.args_mut() { + resolve_typearg_exts(node, arg, extensions, used_extensions)?; + } + + let ext_id = custom.extension(); + let ext = extensions.get(ext_id).ok_or_else(|| { + ExtensionResolutionError::missing_type_extension( + node, + custom.name(), + ext_id, + extensions, + ) + })?; + + // Add the extension to the used extensions registry, + // and update the CustomType with the valid pointer. + used_extensions.register_updated_ref(ext); + custom.update_extension(Arc::downgrade(ext)); + } + TypeEnum::Function(f) => { + resolve_type_row_exts(node, &mut f.input, extensions, used_extensions)?; + resolve_type_row_exts(node, &mut f.output, extensions, used_extensions)?; + } + TypeEnum::Sum(SumType::General { rows }) => { + for row in rows.iter_mut() { + resolve_type_row_exts(node, row, extensions, used_extensions)?; + } + } + // Other types do not store extensions. + TypeEnum::Alias(_) + | TypeEnum::RowVar(_) + | TypeEnum::Variable(_, _) + | TypeEnum::Sum(SumType::Unit { .. }) => {} + } + Ok(()) +} + +/// Update all weak Extension pointers in the [`CustomType`]s inside a type arg. +/// +/// Adds the extensions used in the type to the `used_extensions` registry. +fn resolve_typearg_exts( + node: Node, + arg: &mut TypeArg, + extensions: &ExtensionRegistry, + used_extensions: &mut ExtensionRegistry, +) -> Result<(), ExtensionResolutionError> { + match arg { + TypeArg::Type { ty } => resolve_type_exts(node, ty, extensions, used_extensions)?, + TypeArg::Sequence { elems } => { + for elem in elems.iter_mut() { + resolve_typearg_exts(node, elem, extensions, used_extensions)?; + } + } + _ => {} + } + Ok(()) +} diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index d48f596ea..0681d5818 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -358,8 +358,7 @@ mod test { .unwrap(); }) }; - static ref DUMMY_REG: ExtensionRegistry = - ExtensionRegistry::try_new([EXT.clone()]).unwrap(); + static ref DUMMY_REG: ExtensionRegistry = ExtensionRegistry::new([EXT.clone()]); } impl MakeRegisteredOp for DummyEnum { fn extension_id(&self) -> ExtensionId { diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index b42622745..c3412a33b 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -19,13 +19,13 @@ pub use ident::{IdentList, InvalidIdentifier}; pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError}; use portgraph::multiportgraph::MultiPortGraph; -use portgraph::{Hierarchy, PortMut, UnmanagedDenseMap}; +use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap}; use thiserror::Error; pub use self::views::{HugrView, RootTagged}; use crate::core::NodeIndex; use crate::extension::resolution::{ - update_op_extensions, update_op_types_extensions, ExtensionResolutionError, + resolve_op_extensions, resolve_op_types_extensions, ExtensionResolutionError, }; use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; use crate::ops::{OpTag, OpTrait}; @@ -162,7 +162,7 @@ impl Hugr { return Ok(es.clone()); // Can't neither add nor remove, so nothing to do } let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); - *es = ExtensionSet::singleton(&TO_BE_INFERRED).missing_from(&merged); + *es = ExtensionSet::singleton(TO_BE_INFERRED).missing_from(&merged); Ok(es.clone()) } @@ -213,7 +213,11 @@ impl Hugr { // // This is not something we want to expose it the API, so we manually // iterate instead of writing it as a method. - for n in 0..self.node_count() { + // + // Since we don't have a non-borrowing iterator over all the possible + // NodeIds, we have to simulate it by iterating over all possible + // indices and checking if the node exists. + for n in 0..self.graph.node_capacity() { let pg_node = portgraph::NodeIndex::new(n); let node: Node = pg_node.into(); if !self.contains_node(node) { @@ -222,10 +226,10 @@ impl Hugr { let op = &mut self.op_types[pg_node]; - if let Some(extension) = update_op_extensions(node, op, extensions)? { + if let Some(extension) = resolve_op_extensions(node, op, extensions)? { used_extensions.register_updated_ref(extension); } - update_op_types_extensions(node, op, extensions, &mut used_extensions)?; + resolve_op_types_extensions(node, op, extensions, &mut used_extensions)?; } Ok(used_extensions) diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index 6981c7277..74327f970 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -169,12 +169,12 @@ mod test { fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box> { use crate::extension::prelude::Lift; - let reg = ExtensionRegistry::try_new([ + let reg = ExtensionRegistry::new([ PRELUDE.to_owned(), int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned(), - ]) - .unwrap(); + ]); + reg.validate()?; let int_ty = &int_types::INT_TYPES[6]; let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?; @@ -256,12 +256,12 @@ mod test { }; let [q, p] = swap.outputs_arr(); let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?; - let reg = ExtensionRegistry::try_new([ + let reg = ExtensionRegistry::new([ test_quantum_extension::EXTENSION.clone(), PRELUDE.clone(), float_types::EXTENSION.clone(), - ]) - .unwrap(); + ]); + reg.validate()?; let mut h = h.finish_hugr_with_outputs(cx.outputs(), ®)?; assert_eq!(find_dfgs(&h), vec![h.root(), swap.node()]); @@ -333,12 +333,12 @@ mod test { * CX */ // Extension inference here relies on quantum ops not requiring their own test_quantum_extension - let reg = ExtensionRegistry::try_new([ + let reg = ExtensionRegistry::new([ test_quantum_extension::EXTENSION.to_owned(), float_types::EXTENSION.to_owned(), PRELUDE.to_owned(), - ]) - .unwrap(); + ]); + reg.validate()?; let mut outer = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?; let [a, b] = outer.input_wires_arr(); let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?; diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 5d770af4b..43de652a3 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -466,9 +466,8 @@ mod test { #[test] #[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-' fn cfg() -> Result<(), Box> { - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()]) - .unwrap(); + let reg = ExtensionRegistry::new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()]); + reg.validate()?; let listy = list_type(usize_t()); let pop: ExtensionOp = ListOp::pop .with_type(usize_t()) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index e5dc42841..a67ec6f2f 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -793,7 +793,7 @@ pub(in crate::hugr::rewrite) mod test { }; rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); - assert_eq!(hugr.update_validate(&PRELUDE_REGISTRY), Ok(())); + assert_eq!(hugr.update_validate(&test_quantum_extension::REG), Ok(())); assert_eq!(hugr.node_count(), 4); } diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index dba525d39..63be8c109 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -548,7 +548,7 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] -#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}, TypeArg::Extensions{ es: ExtensionSet::singleton(&PRELUDE_ID)} ], &EMPTY_REG).unwrap())] +#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}, TypeArg::Extensions{ es: ExtensionSet::singleton(PRELUDE_ID)} ], &EMPTY_REG).unwrap())] #[case(ops::CallIndirect { signature : Signature::new_endo(vec![bool_t()]) })] fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { check_testing_roundtrip(NodeSer { diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index e98744a82..d4309c92d 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -388,7 +388,8 @@ fn invalid_types() { ) .unwrap(); }); - let reg = ExtensionRegistry::try_new([ext.clone(), PRELUDE.clone()]).unwrap(); + let reg = ExtensionRegistry::new([ext.clone(), PRELUDE.clone()]); + reg.validate().unwrap(); let validate_to_sig_error = |t: CustomType| { let (h, def) = identity_hugr_with_type(Type::new_extension(t)); @@ -569,7 +570,8 @@ fn no_polymorphic_consts() -> Result<(), Box> { .unwrap() .instantiate(vec![TypeArg::new_var_use(0, BOUND)])?, ); - let reg = ExtensionRegistry::try_new([collections::EXTENSION.to_owned()]).unwrap(); + let reg = ExtensionRegistry::new([collections::EXTENSION.to_owned()]); + reg.validate()?; let mut def = FunctionBuilder::new( "myfunc", PolyFuncType::new( @@ -653,7 +655,7 @@ fn instantiate_row_variables() -> Result<(), Box> { let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?; dfb.finish_hugr_with_outputs( eval2.outputs(), - &ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(), + &ExtensionRegistry::new([PRELUDE.clone(), e]), )?; Ok(()) } @@ -693,7 +695,7 @@ fn row_variables() -> Result<(), Box> { let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs( par_func.outputs(), - &ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(), + &ExtensionRegistry::new([PRELUDE.clone(), e]), )?; Ok(()) } @@ -780,12 +782,13 @@ fn test_polymorphic_call() -> Result<(), Box> { f.finish_with_outputs([tup])? }; - let reg = ExtensionRegistry::try_new([e, PRELUDE.clone()])?; + let reg = ExtensionRegistry::new([e, PRELUDE.clone()]); + reg.validate()?; let [func, tup] = d.input_wires_arr(); let call = d.call( f.handle(), &[TypeArg::Extensions { - es: ExtensionSet::singleton(&PRELUDE_ID), + es: ExtensionSet::singleton(PRELUDE_ID), }], [func, tup], ®, diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index ea23e17de..f6494231a 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -849,14 +849,11 @@ mod tests { dfg.finish_with_outputs([w0, w1, w2])? }; let hugr = mod_builder - .finish_hugr( - &ExtensionRegistry::try_new([ - prelude::PRELUDE.to_owned(), - test_quantum_extension::EXTENSION.to_owned(), - float_types::EXTENSION.to_owned(), - ]) - .unwrap(), - ) + .finish_hugr(&ExtensionRegistry::new([ + prelude::PRELUDE.to_owned(), + test_quantum_extension::EXTENSION.to_owned(), + float_types::EXTENSION.to_owned(), + ])) .map_err(|e| -> BuildError { e.into() })?; Ok((hugr, func_id.node())) } @@ -1139,14 +1136,11 @@ mod tests { let extracted = subgraph.extract_subgraph(&hugr, "region"); extracted - .validate( - &ExtensionRegistry::try_new([ - prelude::PRELUDE.to_owned(), - test_quantum_extension::EXTENSION.to_owned(), - float_types::EXTENSION.to_owned(), - ]) - .unwrap(), - ) + .validate(&ExtensionRegistry::new([ + prelude::PRELUDE.to_owned(), + test_quantum_extension::EXTENSION.to_owned(), + float_types::EXTENSION.to_owned(), + ])) .unwrap(); } diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 002160840..26cd6fd28 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -1068,7 +1068,7 @@ impl<'a> Context<'a> { let ext_ident = IdentList::new(*ext).map_err(|_| { model::ModelError::MalformedName(ext.to_smolstr()) })?; - es.insert(&ext_ident); + es.insert(ext_ident); } model::ExtSetPart::Splice(term_id) => { // The order in an extension set does not matter. diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index d3c24a89a..05f854589 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -9,8 +9,11 @@ pub mod module; pub mod sum; pub mod tag; pub mod validate; +use crate::extension::resolution::{ + collect_op_extension, collect_op_types_extensions, ExtensionCollectionError, +}; use crate::extension::simple_op::MakeExtensionOp; -use crate::extension::{ExtensionId, ExtensionSet}; +use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet}; use crate::types::{EdgeKind, Signature}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; @@ -309,6 +312,20 @@ impl OpType { _ => None, } } + + /// Returns a registry with all the extensions required by the operation. + /// + /// This includes the operation extension in [`OpType::extension_id`], and any + /// extension required by the operation's signature types. + pub fn used_extensions(&self) -> Result { + // Collect extensions on the types. + let mut reg = collect_op_types_extensions(None, self)?; + // And on the operation definition itself. + if let Some(ext) = collect_op_extension(None, self)? { + reg.register_updated(ext); + } + Ok(reg) + } } /// Macro used by operations that want their diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index f79f05d12..3db924a95 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -594,7 +594,7 @@ mod test { } fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(self.0.extension()) + ExtensionSet::singleton(self.0.extension().clone()) } fn get_type(&self) -> Type { @@ -614,7 +614,7 @@ mod test { } fn test_registry() -> ExtensionRegistry { - ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap() + ExtensionRegistry::new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]) } /// Constructs a DFG hugr defining a sum constant, and returning the loaded value. diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index b6755d5c6..f9543cfef 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -41,7 +41,7 @@ use super::{Value, ValueName}; /// #[typetag::serde] /// impl CustomConst for CC { /// fn name(&self) -> ValueName { "CC".into() } -/// fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::singleton(&int_types::EXTENSION_ID) } +/// fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::singleton(int_types::EXTENSION_ID) } /// fn get_type(&self) -> Type { int_types::INT_TYPES[5].clone() } /// } /// diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index be4ec01b9..e6d5793f8 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -116,6 +116,11 @@ impl ExtensionOp { pub fn signature_mut(&mut self) -> &mut Signature { &mut self.signature } + + /// Returns a mutable reference to the type arguments of the operation. + pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] { + self.args.as_mut_slice() + } } impl From for OpaqueOp { @@ -235,6 +240,11 @@ impl OpaqueOp { pub fn extension(&self) -> &ExtensionId { &self.extension } + + /// Returns a mutable reference to the type arguments of the operation. + pub(crate) fn args_mut(&mut self) -> &mut [TypeArg] { + self.args.as_mut_slice() + } } impl DataflowOpTrait for OpaqueOp { @@ -299,7 +309,7 @@ mod test { use ops::OpType; - use crate::extension::resolution::update_op_extensions; + use crate::extension::resolution::resolve_op_extensions; use crate::std_extensions::arithmetic::conversions::{self, CONVERT_OPS_REGISTRY}; use crate::{ extension::{ @@ -349,7 +359,7 @@ mod test { Signature::new(i0.clone(), bool_t()), ); let mut resolved = opaque.into(); - update_op_extensions( + resolve_op_extensions( Node::from(portgraph::NodeIndex::new(1)), &mut resolved, registry, @@ -383,7 +393,8 @@ mod test { }); let ext_id = ext.name().clone(); - let registry = ExtensionRegistry::try_new([ext]).unwrap(); + let registry = ExtensionRegistry::new([ext]); + registry.validate().unwrap(); let opaque_val = OpaqueOp::new( ext_id.clone(), val_name, @@ -393,7 +404,7 @@ mod test { ); let opaque_comp = OpaqueOp::new(ext_id.clone(), comp_name, "".into(), vec![], endo_sig); let mut resolved_val = opaque_val.into(); - update_op_extensions( + resolve_op_extensions( Node::from(portgraph::NodeIndex::new(1)), &mut resolved_val, ®istry, @@ -402,7 +413,7 @@ mod test { assert_eq!(resolve_res_definition(&resolved_val).name(), val_name); let mut resolved_comp = opaque_comp.into(); - update_op_extensions( + resolve_op_extensions( Node::from(portgraph::NodeIndex::new(2)), &mut resolved_comp, ®istry, diff --git a/hugr-core/src/std_extensions.rs b/hugr-core/src/std_extensions.rs index 2896b5b7a..a2bc40ed0 100644 --- a/hugr-core/src/std_extensions.rs +++ b/hugr-core/src/std_extensions.rs @@ -11,7 +11,7 @@ pub mod ptr; /// Extension registry with all standard extensions and prelude. pub fn std_reg() -> ExtensionRegistry { - ExtensionRegistry::try_new([ + let reg = ExtensionRegistry::new([ crate::extension::prelude::PRELUDE.clone(), arithmetic::int_ops::EXTENSION.to_owned(), arithmetic::int_types::EXTENSION.to_owned(), @@ -21,8 +21,10 @@ pub fn std_reg() -> ExtensionRegistry { collections::EXTENSION.to_owned(), logic::EXTENSION.to_owned(), ptr::EXTENSION.to_owned(), - ]) - .unwrap() + ]); + reg.validate() + .expect("Standard extension registry is valid"); + reg } lazy_static::lazy_static! { diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index b8b2771c2..bb3badc57 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -170,13 +170,12 @@ lazy_static! { }; /// Registry of extensions required to validate integer operations. - pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ PRELUDE.clone(), super::int_types::EXTENSION.clone(), super::float_types::EXTENSION.clone(), EXTENSION.clone(), - ]) - .unwrap(); + ]); } impl MakeRegisteredOp for ConvertOpType { diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index dad35d3c7..ce6f30e15 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -111,18 +111,17 @@ lazy_static! { /// Extension for basic float operations. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements(ExtensionSet::singleton(&super::int_types::EXTENSION_ID)); + extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID)); FloatOps::load_all_ops(extension, extension_ref).unwrap(); }) }; /// Registry of extensions required to validate float operations. - pub static ref FLOAT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + pub static ref FLOAT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ PRELUDE.clone(), super::float_types::EXTENSION.clone(), EXTENSION.clone(), - ]) - .unwrap(); + ]); } impl MakeRegisteredOp for FloatOps { diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 304f44899..3275548dd 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -84,7 +84,7 @@ impl CustomConst for ConstF64 { } fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(&EXTENSION_ID) + ExtensionSet::singleton(EXTENSION_ID) } } diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 4f6332767..6f4b248d1 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -254,18 +254,17 @@ lazy_static! { /// Extension for basic integer operations. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements(ExtensionSet::singleton(&super::int_types::EXTENSION_ID)); + extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID)); IntOpDef::load_all_ops(extension, extension_ref).unwrap(); }) }; /// Registry of extensions required to validate integer operations. - pub static ref INT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + pub static ref INT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ PRELUDE.clone(), super::int_types::EXTENSION.clone(), EXTENSION.clone(), - ]) - .unwrap(); + ]); } impl HasConcrete for IntOpDef { diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index db52c6576..88ee9d154 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -182,7 +182,7 @@ impl CustomConst for ConstInt { } fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(&EXTENSION_ID) + ExtensionSet::singleton(EXTENSION_ID) } fn get_type(&self) -> Type { diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 492edf428..384c38a50 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -287,11 +287,10 @@ lazy_static! { }; /// Registry of extensions required to validate list operations. - pub static ref COLLECTIONS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + pub static ref COLLECTIONS_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ PRELUDE.clone(), EXTENSION.clone(), - ]) - .unwrap(); + ]); } impl MakeRegisteredOp for ListOp { @@ -367,15 +366,14 @@ impl ListOpInst { /// Convert this list operation to an [`ExtensionOp`] by providing a /// registry to validate the element type against. pub fn to_extension_op(self, elem_type_registry: &ExtensionRegistry) -> Option { - let registry = ExtensionRegistry::try_new( + let registry = ExtensionRegistry::new( elem_type_registry .clone() .into_iter() // ignore self if already in registry .filter(|ext| ext.name() != EXTENSION.name()) .chain(std::iter::once(EXTENSION.to_owned())), - ) - .unwrap(); + ); ExtensionOp::new( registry.get(&EXTENSION_ID)?.get_op(&self.name())?.clone(), self.type_args(), @@ -438,9 +436,7 @@ mod test { #[test] fn test_list_ops() { - let reg = - ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]) - .unwrap(); + let reg = ExtensionRegistry::new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]); let pop_op = ListOp::pop.with_type(qb_t()); let pop_ext = pop_op.clone().to_extension_op(®).unwrap(); assert_eq!(ListOpInst::from_extension_op(&pop_ext).unwrap(), pop_op); diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index a4f73ed4f..adcb2c970 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -130,8 +130,7 @@ lazy_static! { /// Reference to the logic Extension. pub static ref EXTENSION: Arc = extension(); /// Registry required to validate logic extension. - pub static ref LOGIC_REG: ExtensionRegistry = - ExtensionRegistry::try_new([EXTENSION.clone()]).unwrap(); + pub static ref LOGIC_REG: ExtensionRegistry = ExtensionRegistry::new([EXTENSION.clone()]); } impl MakeRegisteredOp for LogicOp { diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index 774ea10eb..22587ec3e 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -110,8 +110,7 @@ lazy_static! { /// Reference to the pointer Extension. pub static ref EXTENSION: Arc = extension(); /// Registry required to validate pointer extension. - pub static ref PTR_REG: ExtensionRegistry = - ExtensionRegistry::try_new([EXTENSION.clone()]).unwrap(); + pub static ref PTR_REG: ExtensionRegistry = ExtensionRegistry::new([EXTENSION.clone()]); } /// Integer type of a given bit width (specified by the TypeArg). Depending on @@ -272,8 +271,7 @@ pub(crate) mod test { fn test_build() { let in_row = vec![bool_t(), float64_type()]; - let reg = - ExtensionRegistry::try_new([EXTENSION.to_owned(), FLOAT_EXTENSION.to_owned()]).unwrap(); + let reg = ExtensionRegistry::new([EXTENSION.to_owned(), FLOAT_EXTENSION.to_owned()]); let hugr = { let mut builder = DFGBuilder::new( Signature::new(in_row.clone(), type_row![]).with_extension_delta(EXTENSION_ID), diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 24a6f43cd..54ba5dc13 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -17,7 +17,7 @@ use crate::utils::display_list_with_separator; pub use check::SumTypeError; pub use custom::CustomType; pub use poly_func::{PolyFuncType, PolyFuncTypeRV}; -pub use signature::{FuncValueType, Signature}; +pub use signature::{FuncTypeBase, FuncValueType, Signature}; use smol_str::SmolStr; pub use type_param::TypeArg; pub use type_row::{TypeRow, TypeRowRV}; @@ -25,8 +25,6 @@ pub use type_row::{TypeRow, TypeRowRV}; // Unused in --no-features #[allow(unused_imports)] pub(crate) use poly_func::PolyFuncTypeBase; -#[allow(unused_imports)] -pub(crate) use signature::FuncTypeBase; use itertools::FoldWhile::{Continue, Done}; use itertools::{repeat_n, Itertools}; diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 22af9b77a..81bc814ac 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -137,6 +137,11 @@ impl CustomType { &self.args } + /// Returns a mutable reference to the type arguments. + pub(crate) fn args_mut(&mut self) -> &mut Vec { + &mut self.args + } + /// Parent extension. pub fn extension(&self) -> &ExtensionId { &self.extension diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 754e32205..66dff8245 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -177,7 +177,7 @@ pub(crate) mod test { lazy_static! { static ref REGISTRY: ExtensionRegistry = - ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned()]).unwrap(); + ExtensionRegistry::new([PRELUDE.to_owned(), EXTENSION.to_owned()]); } impl PolyFuncTypeBase { @@ -345,7 +345,8 @@ pub(crate) mod test { .unwrap(); }); - let reg = ExtensionRegistry::try_new([ext.clone()]).unwrap(); + let reg = ExtensionRegistry::new([ext.clone()]); + reg.validate().unwrap(); let make_scheme = |tp: TypeParam| { PolyFuncTypeBase::new_validated( diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index faf56abd2..c5d3459e1 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -9,6 +9,7 @@ use super::type_row::TypeRowBase; use super::{MaybeRV, NoRV, RowVariable, Substitution, Type, TypeRow}; use crate::core::PortIndex; +use crate::extension::resolution::{collect_signature_exts, ExtensionCollectionError}; use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; use crate::{Direction, IncomingPort, OutgoingPort, Port}; @@ -118,6 +119,28 @@ impl FuncTypeBase { self.output.validate(extension_registry, var_decls)?; self.extension_reqs.validate(var_decls) } + + /// Returns a registry with the concrete extensions used by this signature. + /// + /// Note that extension type parameters are not included, as they have not + /// been instantiated yet. + /// + /// This method only returns extensions actually used by the types in the + /// signature. The extension deltas added via [`Self::with_extension_delta`] + /// refer to _runtime_ extensions, which may not be in all places that + /// manipulate a HUGR. + pub fn used_extensions(&self) -> Result { + let mut used = ExtensionRegistry::default(); + let mut missing = ExtensionSet::new(); + + collect_signature_exts(self, &mut used, &mut missing); + + if missing.is_empty() { + Ok(used) + } else { + Err(ExtensionCollectionError::dropped_signature(self, missing)) + } + } } impl FuncValueType { diff --git a/hugr-core/src/utils.rs b/hugr-core/src/utils.rs index dba98384c..f0c0d7b97 100644 --- a/hugr-core/src/utils.rs +++ b/hugr-core/src/utils.rs @@ -194,13 +194,13 @@ pub(crate) mod test_quantum_extension { pub static ref EXTENSION: Arc = extension(); /// A registry with all necessary extensions to run tests internally, including the test quantum extension. - pub static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([ + pub static ref REG: ExtensionRegistry = ExtensionRegistry::new([ EXTENSION.clone(), PRELUDE.clone(), float_types::EXTENSION.clone(), float_ops::EXTENSION.clone(), logic::EXTENSION.clone() - ]).unwrap(); + ]); } diff --git a/hugr-llvm/src/emit/ops/cfg.rs b/hugr-llvm/src/emit/ops/cfg.rs index 44fe2d410..c85d5cded 100644 --- a/hugr-llvm/src/emit/ops/cfg.rs +++ b/hugr-llvm/src/emit/ops/cfg.rs @@ -244,13 +244,10 @@ mod test { let hugr = SimpleHugrConfig::new() .with_ins(vec![t1.clone(), t2.clone()]) .with_outs(t2.clone()) - .with_extensions( - ExtensionRegistry::try_new([ - int_types::EXTENSION.to_owned(), - prelude::PRELUDE.to_owned(), - ]) - .unwrap(), - ) + .with_extensions(ExtensionRegistry::new([ + int_types::EXTENSION.to_owned(), + prelude::PRELUDE.to_owned(), + ])) .finish(|mut builder| { let [in1, in2] = builder.input_wires_arr(); let mut cfg_builder = builder diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index f1c9eda8c..c8a53902c 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -5,7 +5,7 @@ use hugr_core::builder::{ BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer, }; use hugr_core::extension::prelude::PRELUDE_ID; -use hugr_core::extension::{ExtensionRegistry, ExtensionSet, EMPTY_REG}; +use hugr_core::extension::{ExtensionRegistry, ExtensionSet}; use hugr_core::ops::handle::FuncID; use hugr_core::std_extensions::arithmetic::{ conversions, float_ops, float_types, int_ops, int_types, @@ -108,7 +108,7 @@ impl SimpleHugrConfig { Self { ins: Default::default(), outs: Default::default(), - extensions: EMPTY_REG, + extensions: Default::default(), } } diff --git a/hugr-llvm/src/extension/collections.rs b/hugr-llvm/src/extension/collections.rs index 60be7e3cd..5a60c6f54 100644 --- a/hugr-llvm/src/extension/collections.rs +++ b/hugr-llvm/src/extension/collections.rs @@ -398,11 +398,11 @@ mod test { &collections::COLLECTIONS_REGISTRY, ) .unwrap(); - let es = ExtensionRegistry::try_new([ + let es = ExtensionRegistry::new([ collections::EXTENSION.to_owned(), prelude::PRELUDE.to_owned(), - ]) - .unwrap(); + ]); + es.validate().unwrap(); let hugr = SimpleHugrConfig::new() .with_ins(ext_op.signature().input().clone()) .with_outs(ext_op.signature().output().clone()) @@ -423,11 +423,11 @@ mod test { fn test_const_list_emmission(mut llvm_ctx: TestContext) { let elem_ty = usize_t(); let contents = (1..4).map(|i| Value::extension(ConstUsize::new(i))); - let es = ExtensionRegistry::try_new([ + let es = ExtensionRegistry::new([ collections::EXTENSION.to_owned(), prelude::PRELUDE.to_owned(), - ]) - .unwrap(); + ]); + es.validate().unwrap(); let hugr = SimpleHugrConfig::new() .with_ins(vec![]) diff --git a/hugr-llvm/src/extension/logic.rs b/hugr-llvm/src/extension/logic.rs index 88dc77a2f..a1d42d3f6 100644 --- a/hugr-llvm/src/extension/logic.rs +++ b/hugr-llvm/src/extension/logic.rs @@ -110,7 +110,7 @@ mod test { SimpleHugrConfig::new() .with_ins(vec![bool_t(); arity]) .with_outs(vec![bool_t()]) - .with_extensions(ExtensionRegistry::try_new(vec![logic::EXTENSION.to_owned()]).unwrap()) + .with_extensions(ExtensionRegistry::new(vec![logic::EXTENSION.to_owned()])) .finish(|mut builder| { let outputs = builder .add_dataflow_op(op, builder.input_wires()) diff --git a/hugr-llvm/src/extension/prelude/array.rs b/hugr-llvm/src/extension/prelude/array.rs index a2037a0f5..787b0a60b 100644 --- a/hugr-llvm/src/extension/prelude/array.rs +++ b/hugr-llvm/src/extension/prelude/array.rs @@ -595,13 +595,12 @@ mod test { } fn exec_registry() -> ExtensionRegistry { - ExtensionRegistry::try_new([ + ExtensionRegistry::new([ int_types::EXTENSION.to_owned(), int_ops::EXTENSION.to_owned(), logic::EXTENSION.to_owned(), prelude::PRELUDE.to_owned(), ]) - .unwrap() } fn exec_extension_set() -> ExtensionSet { diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index a2208430d..1ab6ff41b 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -25,7 +25,7 @@ pub(crate) mod test { lazy_static! { /// A registry containing various extensions for testing. - pub(crate) static ref TEST_REG: ExtensionRegistry = ExtensionRegistry::try_new([ + pub(crate) static ref TEST_REG: ExtensionRegistry = ExtensionRegistry::new([ PRELUDE.to_owned(), arithmetic::int_ops::EXTENSION.to_owned(), arithmetic::int_types::EXTENSION.to_owned(), @@ -34,7 +34,6 @@ pub(crate) mod test { logic::EXTENSION.to_owned(), arithmetic::conversions::EXTENSION.to_owned(), collections::EXTENSION.to_owned(), - ]) - .unwrap(); + ]); } } diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 1acda2ba7..15752d8cc 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -231,7 +231,8 @@ mod test { let exit_types: TypeRow = vec![usize_t()].into(); let e = extension(); let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; - let reg = ExtensionRegistry::try_new([PRELUDE.clone(), e])?; + let reg = ExtensionRegistry::new([PRELUDE.clone(), e]); + reg.validate()?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; let n = no_b1.add_dataflow_op(Noop::new(qb_t()), no_b1.input_wires())?; @@ -358,7 +359,8 @@ mod test { h.branch(&bb2, 0, &bb3)?; h.branch(&bb3, 0, &h.exit_block())?; - let reg = ExtensionRegistry::try_new([e, PRELUDE.clone()])?; + let reg = ExtensionRegistry::new([e, PRELUDE.clone()]); + reg.validate()?; let mut h = h.finish_hugr(®)?; let root = h.root(); merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index 3a96696e8..e7bb54c93 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -81,8 +81,7 @@ //! lazy_static! { //! /// Quantum extension definition. //! pub static ref EXTENSION: Arc = extension(); -//! pub static ref REG: ExtensionRegistry = -//! ExtensionRegistry::try_new([EXTENSION.clone(), PRELUDE.clone()]).unwrap(); +//! pub static ref REG: ExtensionRegistry = ExtensionRegistry::new([EXTENSION.clone(), PRELUDE.clone()]); //! } //! fn get_gate(gate_name: impl Into) -> ExtensionOp { //! EXTENSION