diff --git a/Cargo.toml b/Cargo.toml index b8f806474..3ecd2a8c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,6 +71,7 @@ indexmap = "2.3.0" fxhash = "0.2.1" bumpalo = { version = "3.16.0" } pathsearch = "0.2.0" +base64 = "0.22.1" [profile.dev.package] insta.opt-level = 3 diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 6a08b4a78..828047ea2 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -922,7 +922,7 @@ impl<'a> Context<'a> { model::Term::Var { .. } => Err(error_unsupported!("type variable as `TypeParam`")), model::Term::Apply { .. } => Err(error_unsupported!("custom type as `TypeParam`")), model::Term::ApplyFull { .. } => Err(error_unsupported!("custom type as `TypeParam`")), - + model::Term::BytesType { .. } => Err(error_unsupported!("`bytes` as `TypeParam`")), model::Term::Const { .. } => Err(error_unsupported!("`(const ...)` as `TypeParam`")), model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")), @@ -946,6 +946,7 @@ impl<'a> Context<'a> { | model::Term::Control { .. } | model::Term::NonLinearConstraint { .. } | model::Term::ConstFunc { .. } + | model::Term::Bytes { .. } | model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()), model::Term::ControlType => { @@ -999,6 +1000,8 @@ impl<'a> Context<'a> { model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")), model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")), model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")), + model::Term::BytesType => Err(error_unsupported!("`bytes` as `TypeArg`")), + model::Term::Bytes { .. } => Err(error_unsupported!("`(bytes ..)` as `TypeArg`")), model::Term::Const { .. } => Err(error_unsupported!("`const` as `TypeArg`")), model::Term::ConstAdt { .. } => Err(error_unsupported!("adt constant as `TypeArg`")), model::Term::ConstFunc { .. } => { @@ -1126,6 +1129,8 @@ impl<'a> Context<'a> { | model::Term::ControlType | model::Term::Nat(_) | model::Term::NonLinearConstraint { .. } + | model::Term::Bytes { .. } + | model::Term::BytesType | model::Term::ConstFunc { .. } | model::Term::ConstAdt { .. } => Err(model::ModelError::TypeError(term_id).into()), } @@ -1363,6 +1368,8 @@ impl<'a> Context<'a> { | model::Term::Control { .. } | model::Term::ControlType | model::Term::Type + | model::Term::Bytes { .. } + | model::Term::BytesType | model::Term::NonLinearConstraint { .. } => { Err(model::ModelError::TypeError(term_id).into()) } diff --git a/hugr-model/Cargo.toml b/hugr-model/Cargo.toml index e23ef3a6e..d1e3636b9 100644 --- a/hugr-model/Cargo.toml +++ b/hugr-model/Cargo.toml @@ -16,6 +16,7 @@ license.workspace = true bench = false [dependencies] +base64 = { workspace = true } bumpalo = { workspace = true, features = ["collections"] } capnp = "0.20.1" derive_more = { version = "1.0.0", features = ["display"] } diff --git a/hugr-model/capnp/hugr-v0.capnp b/hugr-model/capnp/hugr-v0.capnp index 060e8af3b..a06665d1d 100644 --- a/hugr-model/capnp/hugr-v0.capnp +++ b/hugr-model/capnp/hugr-v0.capnp @@ -157,6 +157,8 @@ struct Term { nonLinearConstraint @20 :TermId; constFunc @22 :RegionId; constAdt @23 :ConstAdt; + bytes @24 :Data; + bytesType @25 :Void; } struct Apply { diff --git a/hugr-model/src/v0/binary/read.rs b/hugr-model/src/v0/binary/read.rs index 2ea0e7742..69ef45b14 100644 --- a/hugr-model/src/v0/binary/read.rs +++ b/hugr-model/src/v0/binary/read.rs @@ -335,6 +335,11 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult let values = model::TermId(reader.get_values()); model::Term::ConstAdt { tag, values } } + + Which::Bytes(bytes) => model::Term::Bytes { + data: bump.alloc_slice_copy(bytes?), + }, + Which::BytesType(()) => model::Term::BytesType, }) } diff --git a/hugr-model/src/v0/binary/write.rs b/hugr-model/src/v0/binary/write.rs index 3a1e1beba..56511ba12 100644 --- a/hugr-model/src/v0/binary/write.rs +++ b/hugr-model/src/v0/binary/write.rs @@ -217,6 +217,14 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) { builder.set_tag(*tag); builder.set_values(values.0); } + + model::Term::Bytes { data } => { + builder.set_bytes(data); + } + + model::Term::BytesType => { + builder.set_bytes_type(()); + } } } diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index f9da742e9..a5c98c8fb 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -686,6 +686,15 @@ pub enum Term<'a> { /// The values of the variant. values: TermId, }, + + /// A literal byte string. + Bytes { + /// The data of the byte string. + data: &'a [u8], + }, + + /// The type of byte strings. + BytesType, } /// A part of a list term. diff --git a/hugr-model/src/v0/text/hugr.pest b/hugr-model/src/v0/text/hugr.pest index 3d37b9878..1e12b1839 100644 --- a/hugr-model/src/v0/text/hugr.pest +++ b/hugr-model/src/v0/text/hugr.pest @@ -10,6 +10,8 @@ string_raw = @{ (!("\\" | "\"") ~ ANY)+ } string_escape = @{ "\\" ~ ("\"" | "\\" | "n" | "r" | "t") } string_unicode = @{ "\\u" ~ "{" ~ ASCII_HEX_DIGIT+ ~ "}" } +base64_string = { "\"" ~ (ASCII_ALPHANUMERIC | "+" | "/")* ~ "="* ~ "\"" } + module = { "(" ~ "hugr" ~ "0" ~ ")" ~ meta* ~ node* ~ EOI } meta = { "(" ~ "meta" ~ symbol ~ term ~ ")" } @@ -97,6 +99,8 @@ term = { | term_non_linear | term_const_func | term_const_adt + | term_bytes_type + | term_bytes } term_wildcard = { "_" } @@ -122,5 +126,7 @@ term_ctrl_type = { "ctrl" } term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" } term_const_func = { "(" ~ "fn" ~ term ~ ")" } term_const_adt = { "(" ~ "tag" ~ tag ~ term* ~ ")" } +term_bytes_type = { "bytes" } +term_bytes = { "(" ~ "bytes" ~ base64_string ~ ")" } spliced_term = { term ~ "..." } diff --git a/hugr-model/src/v0/text/parse.rs b/hugr-model/src/v0/text/parse.rs index 34435b2a3..09d6413c9 100644 --- a/hugr-model/src/v0/text/parse.rs +++ b/hugr-model/src/v0/text/parse.rs @@ -1,3 +1,4 @@ +use base64::{prelude::BASE64_STANDARD, Engine}; use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump}; use fxhash::FxHashMap; use pest::{ @@ -262,6 +263,20 @@ impl<'a> ParseContext<'a> { Term::ConstAdt { tag, values } } + Rule::term_bytes_type => Term::BytesType, + + Rule::term_bytes => { + let token = inner.next().unwrap(); + let slice = token.as_str(); + // Remove the quotes + let slice = &slice[1..slice.len() - 1]; + let data = BASE64_STANDARD.decode(slice).map_err(|_| { + ParseError::custom("invalid base64 encoding", token.as_span()) + })?; + let data = self.bump.alloc_slice_copy(&data); + Term::Bytes { data } + } + r => unreachable!("term: {:?}", r), }; diff --git a/hugr-model/src/v0/text/print.rs b/hugr-model/src/v0/text/print.rs index 430f07e99..53bcc37b7 100644 --- a/hugr-model/src/v0/text/print.rs +++ b/hugr-model/src/v0/text/print.rs @@ -1,3 +1,4 @@ +use base64::{prelude::BASE64_STANDARD, Engine}; use pretty::{Arena, DocAllocator, RefDoc}; use std::borrow::Cow; @@ -598,6 +599,15 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { this.print_text(tag.to_string()); this.print_term(*values) }), + Term::BytesType => { + self.print_text("bytes"); + Ok(()) + } + Term::Bytes { data } => self.print_parens(|this| { + this.print_text("bytes"); + this.print_byte_string(data); + Ok(()) + }), } } @@ -717,4 +727,14 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> { output.push('"'); self.print_text(output); } + + /// Print a bytes literal. + fn print_byte_string(&mut self, bytes: &[u8]) { + // every 3 bytes are encoded into 4 characters + let mut output = String::with_capacity(2 + bytes.len().div_ceil(3) * 4); + output.push('"'); + BASE64_STANDARD.encode_string(bytes, &mut output); + output.push('"'); + self.print_text(output); + } } diff --git a/hugr-model/tests/binary.rs b/hugr-model/tests/binary.rs index 6a00a6b35..6120abb60 100644 --- a/hugr-model/tests/binary.rs +++ b/hugr-model/tests/binary.rs @@ -68,3 +68,8 @@ pub fn test_lists() { pub fn test_const() { binary_roundtrip(include_str!("fixtures/model-const.edn")); } + +#[test] +pub fn test_literals() { + binary_roundtrip(include_str!("fixtures/model-literals.edn")); +} diff --git a/hugr-model/tests/fixtures/model-literals.edn b/hugr-model/tests/fixtures/model-literals.edn index 552155dda..68087e37f 100644 --- a/hugr-model/tests/fixtures/model-literals.edn +++ b/hugr-model/tests/fixtures/model-literals.edn @@ -1,3 +1,4 @@ (hugr 0) (define-alias mod.string str "\"\n\r\t\\\u{1F44D}") +(define-alias mod.bytes bytes (bytes "SGVsbG8gd29ybGQg8J+Yig==")) diff --git a/hugr-model/tests/snapshots/text__literals.snap b/hugr-model/tests/snapshots/text__literals.snap index 4a639d8e7..e767e310a 100644 --- a/hugr-model/tests/snapshots/text__literals.snap +++ b/hugr-model/tests/snapshots/text__literals.snap @@ -5,3 +5,5 @@ expression: "roundtrip(include_str!(\"fixtures/model-literals.edn\"))" (hugr 0) (define-alias mod.string str "\"\n\r\t\\👍") + +(define-alias mod.bytes bytes (bytes "SGVsbG8gd29ybGQg8J+Yig=="))