Skip to content

Commit

Permalink
Add a ExtensionRegistry to each Hugr
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Dec 4, 2024
1 parent 3be18e9 commit 1298826
Show file tree
Hide file tree
Showing 52 changed files with 814 additions and 1,019 deletions.
49 changes: 34 additions & 15 deletions hugr-cli/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ use clap_verbosity_flag::{InfoLevel, Verbosity};
use clio::Input;
use derive_more::{Display, Error, From};
use hugr::extension::ExtensionRegistry;
use hugr::package::PackageValidationError;
use hugr::package::{PackageEncodingError, PackageValidationError};
use hugr::Hugr;
use std::io::{Cursor, Read, Seek, SeekFrom};
use std::{ffi::OsString, path::PathBuf};

pub mod extensions;
Expand Down Expand Up @@ -46,6 +47,9 @@ pub enum CliError {
/// Error parsing input.
#[display("Error parsing package: {_0}")]
Parse(serde_json::Error),
/// Hugr load error.
#[display("Error parsing package: {_0}")]
HUGRLoad(PackageEncodingError),
#[display("Error validating HUGR: {_0}")]
/// Errors produced by the `validate` subcommand.
Validate(PackageValidationError),
Expand Down Expand Up @@ -96,15 +100,10 @@ impl PackageOrHugr {
}

/// Validates the package or hugr.
///
/// Updates the extension registry with any new extensions defined in the package.
pub fn update_validate(
&mut self,
reg: &mut ExtensionRegistry,
) -> Result<(), PackageValidationError> {
pub fn validate(&self) -> Result<(), PackageValidationError> {
match self {
PackageOrHugr::Package(pkg) => pkg.update_validate(reg),
PackageOrHugr::Hugr(hugr) => hugr.update_validate(reg).map_err(Into::into),
PackageOrHugr::Package(pkg) => pkg.validate(),
PackageOrHugr::Hugr(hugr) => Ok(hugr.validate()?),
}
}
}
Expand All @@ -120,13 +119,33 @@ impl AsRef<[Hugr]> for PackageOrHugr {

impl HugrArgs {
/// Read either a package or a single hugr from the input.
pub fn get_package_or_hugr(&mut self) -> Result<PackageOrHugr, CliError> {
let val: serde_json::Value = serde_json::from_reader(&mut self.input)?;
if let Ok(hugr) = serde_json::from_value::<Hugr>(val.clone()) {
return Ok(PackageOrHugr::Hugr(hugr));
pub fn get_package_or_hugr(
&mut self,
extensions: &ExtensionRegistry,
) -> Result<PackageOrHugr, CliError> {
// We need to read the input twice; once to try to load it as a HUGR, and if that fails, as a package.
// If `input` is a file, we can reuse the reader by seeking back to the start.
// Else, we need to read the file into a buffer.
trait SeekRead: Seek + Read {}
impl<T: Seek + Read> SeekRead for T {}

let mut buffer = Vec::new();
let mut seekable_input: Box<dyn SeekRead> = match self.input.can_seek() {
true => Box::new(&mut self.input),
false => {
self.input.read_to_end(&mut buffer)?;
Box::new(Cursor::new(buffer))
}
};

match Hugr::load_json(&mut seekable_input, extensions) {
Ok(hugr) => Ok(PackageOrHugr::Hugr(hugr)),
Err(_) => {
seekable_input.seek(SeekFrom::Start(0))?;
let pkg = Package::from_json_reader(seekable_input, extensions)?;
Ok(PackageOrHugr::Package(pkg))
}
}
let pkg = serde_json::from_value::<Package>(val.clone())?;
Ok(PackageOrHugr::Package(pkg))
}

/// Read either a package from the input.
Expand Down
16 changes: 14 additions & 2 deletions hugr-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
use clap::Parser as _;

use hugr_cli::{validate, CliArgs};
use hugr_cli::{mermaid, validate, CliArgs};

use clap_verbosity_flag::log::Level;

fn main() {
match CliArgs::parse() {
CliArgs::Validate(args) => run_validate(args),
CliArgs::GenExtensions(args) => args.run_dump(&hugr::std_extensions::STD_REG),
CliArgs::Mermaid(mut args) => args.run_print().unwrap(),
CliArgs::Mermaid(args) => run_mermaid(args),
CliArgs::External(_) => {
// TODO: Implement support for external commands.
// Running `hugr COMMAND` would look for `hugr-COMMAND` in the path
Expand All @@ -36,3 +36,15 @@ fn run_validate(mut args: validate::ValArgs) {
std::process::exit(1);
}
}

/// Run the `mermaid` subcommand.
fn run_mermaid(mut args: mermaid::MermaidArgs) {
let result = args.run_print();

if let Err(e) = result {
if args.hugr_args.verbosity(Level::Error) {
eprintln!("{}", e);
}
std::process::exit(1);
}
}
7 changes: 5 additions & 2 deletions hugr-cli/src/mermaid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ impl MermaidArgs {
/// Write the mermaid diagram to the output.
pub fn run_print(&mut self) -> Result<(), crate::CliError> {
let hugrs = if self.validate {
self.hugr_args.validate()?.0
self.hugr_args.validate()?
} else {
self.hugr_args.get_package_or_hugr()?.into_hugrs()
let extensions = self.hugr_args.extensions()?;
self.hugr_args
.get_package_or_hugr(&extensions)?
.into_hugrs()
};

for hugr in hugrs {
Expand Down
27 changes: 12 additions & 15 deletions hugr-cli/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,6 @@ use hugr::{extension::ExtensionRegistry, Extension, Hugr};

use crate::{CliError, HugrArgs};

// TODO: Deprecated re-export. Remove on a breaking release.
#[doc(inline)]
#[deprecated(
since = "0.13.2",
note = "Use `hugr::package::PackageValidationError` instead."
)]
pub use hugr::package::PackageValidationError as ValError;

/// Validate and visualise a HUGR file.
#[derive(Parser, Debug)]
#[clap(version = "1.0", long_about = None)]
Expand All @@ -31,7 +23,7 @@ pub const VALID_PRINT: &str = "HUGR valid!";

impl ValArgs {
/// Run the HUGR cli and validate against an extension registry.
pub fn run(&mut self) -> Result<(Vec<Hugr>, ExtensionRegistry), CliError> {
pub fn run(&mut self) -> Result<Vec<Hugr>, CliError> {
let result = self.hugr_args.validate()?;
if self.verbosity(Level::Info) {
eprintln!("{}", VALID_PRINT);
Expand All @@ -50,24 +42,29 @@ impl HugrArgs {
///
/// Returns the validated modules and the extension registry the modules
/// were validated against.
pub fn validate(&mut self) -> Result<(Vec<Hugr>, ExtensionRegistry), CliError> {
let mut package = self.get_package_or_hugr()?;
pub fn validate(&mut self) -> Result<Vec<Hugr>, CliError> {
let reg = self.extensions()?;
let package = self.get_package_or_hugr(&reg)?;

package.validate()?;
Ok(package.into_hugrs())
}

let mut reg: ExtensionRegistry = if self.no_std {
/// Return a register with the selected extensions.
pub fn extensions(&self) -> Result<ExtensionRegistry, CliError> {
let mut reg = if self.no_std {
hugr::extension::PRELUDE_REGISTRY.to_owned()
} else {
hugr::std_extensions::STD_REG.to_owned()
};

// register external extensions
for ext in &self.extensions {
let f = std::fs::File::open(ext)?;
let ext: Extension = serde_json::from_reader(f)?;
reg.register_updated(ext);
}

package.update_validate(&mut reg)?;
Ok((package.into_hugrs(), reg))
Ok(reg)
}

/// Test whether a `level` message should be output.
Expand Down
12 changes: 5 additions & 7 deletions hugr-cli/tests/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
//! calling the CLI binary, which Miri doesn't support.
#![cfg(all(test, not(miri)))]

use std::sync::Arc;

use assert_cmd::Command;
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder};
Expand Down Expand Up @@ -49,9 +47,7 @@ fn test_package(#[default(bool_t())] id_type: Type) -> Package {
df.finish_with_outputs([i]).unwrap();
let hugr = module.hugr().clone(); // unvalidated

let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
let float_ext: Arc<hugr::Extension> = serde_json::from_reader(rdr).unwrap();
Package::new(vec![hugr], vec![float_ext]).unwrap()
Package::new(vec![hugr]).unwrap()
}

/// A DFG-rooted HUGR.
Expand Down Expand Up @@ -130,7 +126,9 @@ fn test_mermaid_invalid(bad_hugr_string: String, mut cmd: Command) {
cmd.arg("mermaid");
cmd.arg("--validate");
cmd.write_stdin(bad_hugr_string);
cmd.assert().failure().stderr(contains("UnconnectedPort"));
cmd.assert()
.failure()
.stderr(contains("has an unconnected port"));
}

#[rstest]
Expand All @@ -141,7 +139,7 @@ fn test_bad_hugr(bad_hugr_string: String, mut val_cmd: Command) {
val_cmd
.assert()
.failure()
.stderr(contains("Error validating HUGR").and(contains("unconnected port")));
.stderr(contains("Node(1)").and(contains("unconnected port")));
}

#[rstest]
Expand Down
15 changes: 7 additions & 8 deletions hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@
//! // Finish building the HUGR, consuming the builder.
//! //
//! // Requires a registry with all the extensions used in the module.
//! module_builder.finish_hugr(&LOGIC_REG)
//! module_builder.finish_hugr()
//! }?;
//!
//! // The built HUGR is always valid.
//! hugr.validate(&LOGIC_REG).unwrap_or_else(|e| {
//! hugr.validate().unwrap_or_else(|e| {
//! panic!("HUGR validation failed: {e}");
//! });
//! # Ok(())
Expand Down Expand Up @@ -242,7 +242,6 @@ pub(crate) mod test {
use crate::hugr::{views::HugrView, HugrMut};
use crate::ops;
use crate::types::{PolyFuncType, Signature};
use crate::utils::test_quantum_extension;
use crate::Hugr;

use super::handle::BuildHandle;
Expand All @@ -269,38 +268,38 @@ pub(crate) mod test {

f(f_builder)?;

Ok(module_builder.finish_hugr(&test_quantum_extension::REG)?)
Ok(module_builder.finish_hugr()?)
}

#[fixture]
pub(crate) fn simple_dfg_hugr() -> Hugr {
let dfg_builder = DFGBuilder::new(Signature::new(vec![bool_t()], vec![bool_t()])).unwrap();
let [i1] = dfg_builder.input_wires_arr();
dfg_builder.finish_prelude_hugr_with_outputs([i1]).unwrap()
dfg_builder.finish_hugr_with_outputs([i1]).unwrap()
}

#[fixture]
pub(crate) fn simple_funcdef_hugr() -> Hugr {
let fn_builder =
FunctionBuilder::new("test", Signature::new(vec![bool_t()], vec![bool_t()])).unwrap();
let [i1] = fn_builder.input_wires_arr();
fn_builder.finish_prelude_hugr_with_outputs([i1]).unwrap()
fn_builder.finish_hugr_with_outputs([i1]).unwrap()
}

#[fixture]
pub(crate) fn simple_module_hugr() -> Hugr {
let mut builder = ModuleBuilder::new();
let sig = Signature::new(vec![bool_t()], vec![bool_t()]);
builder.declare("test", sig.into()).unwrap();
builder.finish_prelude_hugr().unwrap()
builder.finish_hugr().unwrap()
}

#[fixture]
pub(crate) fn simple_cfg_hugr() -> Hugr {
let mut cfg_builder =
CFGBuilder::new(Signature::new(vec![usize_t()], vec![usize_t()])).unwrap();
super::cfg::test::build_basic_cfg(&mut cfg_builder).unwrap();
cfg_builder.finish_prelude_hugr().unwrap()
cfg_builder.finish_hugr().unwrap()
}

/// A helper method which creates a DFG rooted hugr with Input and Output node
Expand Down
Loading

0 comments on commit 1298826

Please sign in to comment.