Skip to content

Commit

Permalink
feat!: used_extensions calls for both ops and signatures (#1739)
Browse files Browse the repository at this point in the history
Adds methods
- `OpType::used_extensions(&self) -> Result<ExtensionRegistry, _>`
- `Signature::used_extensions(&self) -> Result<ExtensionRegistry, _>`
and tests these along with the code merged in #1735.

Moves the code from #1735 into `resolution::types_mut`, and adds a
(quite-similar) non-mutable version in `::types` that only collects the
extensions without modifying the `OpType`.

Fixes the resolution not exploring types inside a `CustomType` type
arguments.

drive-by: Implement `Display`, `Serialize`, and `::new` for
`ExtensionRegistry`.
drive-by: `ExtensionSet` should take ids by value when inserting.
drive-by: Fix `Hugr::resolve_extension_defs` not scanning all the ops.

These changes were extracted from #1738.

BREAKING CHANGE: Removed `ExtensionRegistry::try_new`. Use `new`
instead, and call `ExtensionRegistry::validate` to validate.
BREAKING CHANGE: `ExtensionSet::insert` and `singleton` take extension
ids by value instead of cloning internally.
  • Loading branch information
aborgna-q authored Dec 10, 2024
1 parent 301f61d commit 1443284
Show file tree
Hide file tree
Showing 45 changed files with 1,097 additions and 305 deletions.
180 changes: 133 additions & 47 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@
//! TODO: YAML declaration and parsing. This should be similar to a plugin
//! system (outside the `types` module), which also parses nested [`OpDef`]s.
use itertools::Itertools;
pub use semver::Version;
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::btree_map;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Debug, Display, Formatter};
use std::fmt::Debug;
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};

use derive_more::Display;
use thiserror::Error;

use crate::hugr::IdentList;
Expand Down Expand Up @@ -40,41 +44,73 @@ pub use type_def::{TypeDef, TypeDefBound};
pub mod declarative;

/// Extension Registries store extensions to be looked up e.g. during validation.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Arc<Extension>>);
#[derive(Debug, Display, Default)]
#[display("ExtensionRegistry[{}]", exts.keys().join(", "))]
pub struct ExtensionRegistry {
/// The extensions in the registry.
exts: BTreeMap<ExtensionId, Arc<Extension>>,
/// A flag indicating whether the current set of extensions has been
/// validated.
///
/// This is used to avoid re-validating the extensions every time the
/// registry is validated, and is set to `false` whenever a new extension is
/// added.
valid: AtomicBool,
}

impl PartialEq for ExtensionRegistry {
fn eq(&self, other: &Self) -> bool {
self.exts == other.exts
}
}

impl Clone for ExtensionRegistry {
fn clone(&self) -> Self {
Self {
exts: self.exts.clone(),
valid: self.valid.load(Ordering::Relaxed).into(),
}
}
}

impl ExtensionRegistry {
/// Create a new empty extension registry.
pub fn new(extensions: impl IntoIterator<Item = Arc<Extension>>) -> Self {
let mut res = Self::default();
for ext in extensions.into_iter() {
res.register_updated(ext);
}
res
}

/// Gets the Extension with the given name
pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
self.0.get(name)
self.exts.get(name)
}

/// Returns `true` if the registry contains an extension with the given name.
pub fn contains(&self, name: &str) -> bool {
self.0.contains_key(name)
self.exts.contains_key(name)
}

/// Makes a new [ExtensionRegistry], validating all the extensions in it.
pub fn try_new(
value: impl IntoIterator<Item = Arc<Extension>>,
) -> Result<Self, ExtensionRegistryError> {
let mut res = ExtensionRegistry(BTreeMap::new());

for ext in value.into_iter() {
res.register(ext)?;
/// Validate the set of extensions, ensuring that each extension requirements are also in the registry.
///
/// Note this potentially asks extensions to validate themselves against other extensions that
/// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
/// or at least to validate the types first - which we don't do at all yet:
//
// TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be
// cyclically dependent, so there is no perfect solution, and this is at least simple.
pub fn validate(&self) -> Result<(), ExtensionRegistryError> {
if self.valid.load(Ordering::Relaxed) {
return Ok(());
}

// Note this potentially asks extensions to validate themselves against other extensions that
// may *not* be valid themselves yet. It'd be better to order these respecting dependencies,
// or at least to validate the types first - which we don't do at all yet:
// TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be
// cyclically dependent, so there is no perfect solution, and this is at least simple.
for ext in res.0.values() {
ext.validate(&res)
for ext in self.exts.values() {
ext.validate(self)
.map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
}

Ok(res)
self.valid.store(true, Ordering::Relaxed);
Ok(())
}

/// Registers a new extension to the registry.
Expand All @@ -85,14 +121,17 @@ impl ExtensionRegistry {
extension: impl Into<Arc<Extension>>,
) -> Result<(), ExtensionRegistryError> {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
match self.exts.entry(extension.name().clone()) {
btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
extension.name().clone(),
prev.get().version().clone(),
extension.version().clone(),
)),
btree_map::Entry::Vacant(ve) => {
ve.insert(extension);
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);

Ok(())
}
}
Expand All @@ -109,7 +148,7 @@ impl ExtensionRegistry {
/// see [`ExtensionRegistry::register_updated_ref`].
pub fn register_updated(&mut self, extension: impl Into<Arc<Extension>>) {
let extension = extension.into();
match self.0.entry(extension.name().clone()) {
match self.exts.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension;
Expand All @@ -119,6 +158,8 @@ impl ExtensionRegistry {
ve.insert(extension);
}
}
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);
}

/// Registers a new extension to the registry, keeping the one most up to
Expand All @@ -131,7 +172,7 @@ impl ExtensionRegistry {
/// Clones the Arc only when required. For no-cloning version see
/// [`ExtensionRegistry::register_updated`].
pub fn register_updated_ref(&mut self, extension: &Arc<Extension>) {
match self.0.entry(extension.name().clone()) {
match self.exts.entry(extension.name().clone()) {
btree_map::Entry::Occupied(mut prev) => {
if prev.get().version() < extension.version() {
*prev.get_mut() = extension.clone();
Expand All @@ -141,31 +182,36 @@ impl ExtensionRegistry {
ve.insert(extension.clone());
}
}
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);
}

/// Returns the number of extensions in the registry.
pub fn len(&self) -> usize {
self.0.len()
self.exts.len()
}

/// Returns `true` if the registry contains no extensions.
pub fn is_empty(&self) -> bool {
self.0.is_empty()
self.exts.is_empty()
}

/// Returns an iterator over the extensions in the registry.
pub fn iter(&self) -> <&Self as IntoIterator>::IntoIter {
self.0.values()
self.exts.values()
}

/// Returns an iterator over the extensions ids in the registry.
pub fn ids(&self) -> impl Iterator<Item = &ExtensionId> {
self.0.keys()
self.exts.keys()
}

/// Delete an extension from the registry and return it if it was present.
pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Arc<Extension>> {
self.0.remove(name)
// Clear the valid flag so that the registry is re-validated.
self.valid.store(false, Ordering::Relaxed);

self.exts.remove(name)
}
}

Expand All @@ -175,7 +221,7 @@ impl IntoIterator for ExtensionRegistry {
type IntoIter = std::collections::btree_map::IntoValues<ExtensionId, Arc<Extension>>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_values()
self.exts.into_values()
}
}

Expand All @@ -185,7 +231,7 @@ impl<'a> IntoIterator for &'a ExtensionRegistry {
type IntoIter = std::collections::btree_map::Values<'a, ExtensionId, Arc<Extension>>;

fn into_iter(self) -> Self::IntoIter {
self.0.values()
self.exts.values()
}
}

Expand All @@ -205,8 +251,33 @@ impl Extend<Arc<Extension>> for ExtensionRegistry {
}
}

// Encode/decode ExtensionRegistry as a list of extensions.
// We can get the map key from the extension itself.
impl<'de> Deserialize<'de> for ExtensionRegistry {
fn deserialize<D>(deserializer: D) -> Result<ExtensionRegistry, D::Error>
where
D: Deserializer<'de>,
{
let extensions: Vec<Arc<Extension>> = Vec::deserialize(deserializer)?;
Ok(ExtensionRegistry::new(extensions))
}
}

impl Serialize for ExtensionRegistry {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let extensions: Vec<Arc<Extension>> = self.exts.values().cloned().collect();
extensions.serialize(serializer)
}
}

/// An Extension Registry containing no extensions.
pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new());
pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry {
exts: BTreeMap::new(),
valid: AtomicBool::new(true),
};

/// An error that can occur in computing the signature of a node.
/// TODO: decide on failure modes
Expand All @@ -226,7 +297,7 @@ pub enum SignatureError {
#[error("Invalid type arguments for operation")]
InvalidTypeArgs,
/// The Extension Registry did not contain an Extension referenced by the Signature
#[error("Extension '{missing}' not found. Available extensions: {}",
#[error("Extension '{missing}' is not part of the declared HUGR extensions [{}]",
available.iter().map(|e| e.to_string()).collect::<Vec<_>>().join(", ")
)]
ExtensionNotFound {
Expand Down Expand Up @@ -614,7 +685,10 @@ pub enum ExtensionBuildError {
}

/// A set of extensions identified by their unique [`ExtensionId`].
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[derive(
Clone, Debug, Display, Default, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize,
)]
#[display("[{}]", _0.iter().join(", "))]
pub struct ExtensionSet(BTreeSet<ExtensionId>);

/// A special ExtensionId which indicates that the delta of a non-Function
Expand All @@ -632,7 +706,7 @@ impl ExtensionSet {
}

/// Adds a extension to the set.
pub fn insert(&mut self, extension: &ExtensionId) {
pub fn insert(&mut self, extension: ExtensionId) {
self.0.insert(extension.clone());
}

Expand Down Expand Up @@ -660,7 +734,7 @@ impl ExtensionSet {
}

/// Create a extension set with a single element.
pub fn singleton(extension: &ExtensionId) -> Self {
pub fn singleton(extension: ExtensionId) -> Self {
let mut set = Self::new();
set.insert(extension);
set
Expand Down Expand Up @@ -724,7 +798,25 @@ impl ExtensionSet {

impl From<ExtensionId> for ExtensionSet {
fn from(id: ExtensionId) -> Self {
Self::singleton(&id)
Self::singleton(id)
}
}

impl IntoIterator for ExtensionSet {
type Item = ExtensionId;
type IntoIter = std::collections::btree_set::IntoIter<ExtensionId>;

fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}

impl<'a> IntoIterator for &'a ExtensionSet {
type Item = &'a ExtensionId;
type IntoIter = std::collections::btree_set::Iter<'a, ExtensionId>;

fn into_iter(self) -> Self::IntoIter {
self.0.iter()
}
}

Expand All @@ -738,12 +830,6 @@ fn as_typevar(e: &ExtensionId) -> Option<usize> {
}
}

impl Display for ExtensionSet {
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
f.debug_list().entries(self.0.iter()).finish()
}
}

impl FromIterator<ExtensionId> for ExtensionSet {
fn from_iter<I: IntoIterator<Item = ExtensionId>>(iter: I) -> Self {
Self(BTreeSet::from_iter(iter))
Expand Down Expand Up @@ -783,8 +869,8 @@ pub mod test {
fn test_register_update() {
// Two registers that should remain the same.
// We use them to test both `register_updated` and `register_updated_ref`.
let mut reg = ExtensionRegistry::try_new([]).unwrap();
let mut reg_ref = ExtensionRegistry::try_new([]).unwrap();
let mut reg = ExtensionRegistry::default();
let mut reg_ref = ExtensionRegistry::default();

let ext_1_id = ExtensionId::new("ext1").unwrap();
let ext_2_id = ExtensionId::new("ext2").unwrap();
Expand Down
4 changes: 2 additions & 2 deletions hugr-core/src/extension/declarative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ impl ExtensionSetDeclaration {
registry.register(PRELUDE.clone())?;
}
if !scope.contains(&PRELUDE_ID) {
scope.insert(&PRELUDE_ID);
scope.insert(PRELUDE_ID);
}

// Registers extensions sequentially, adding them to the current scope.
Expand All @@ -137,7 +137,7 @@ impl ExtensionSetDeclaration {
registry,
};
let ext = decl.make_extension(&self.imports, ctx)?;
scope.insert(ext.name());
scope.insert(ext.name().clone());
registry.register(ext)?;
}

Expand Down
7 changes: 4 additions & 3 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ impl SignatureFunc {
SignatureFunc::MissingValidateFunc(ts) => (ts, args),
};
let mut res = pf.instantiate(args, exts)?;
res.extension_reqs.insert(&def.extension);
res.extension_reqs.insert(def.extension.clone());

// If there are any row variables left, this will fail with an error:
res.try_into()
Expand Down Expand Up @@ -658,7 +658,8 @@ pub(super) mod test {
Ok(())
})?;

let reg = ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), ext]).unwrap();
let reg = ExtensionRegistry::new([PRELUDE.clone(), EXTENSION.clone(), ext]);
reg.validate()?;
let e = reg.get(&EXT_ID).unwrap();

let list_usize =
Expand Down Expand Up @@ -822,7 +823,7 @@ pub(super) mod test {
)?;

// Concrete extension set
let es = ExtensionSet::singleton(&EXT_ID);
let es = ExtensionSet::singleton(EXT_ID);
let exp_fun_ty = Signature::new_endo(bool_t()).with_extension_delta(es.clone());
let args = [TypeArg::Extensions { es }];

Expand Down
Loading

0 comments on commit 1443284

Please sign in to comment.