Skip to content

Commit

Permalink
Bytes literal in hugr-model.
Browse files Browse the repository at this point in the history
  • Loading branch information
zrho committed Jan 8, 2025
1 parent b36d97d commit 4b31d2b
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions hugr-model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ license.workspace = true
bench = false

[dependencies]
base64 = { workspace = true }
bumpalo = { workspace = true, features = ["collections"] }
capnp = "0.20.1"
derive_more = { version = "1.0.0", features = ["display"] }
Expand Down
2 changes: 2 additions & 0 deletions hugr-model/capnp/hugr-v0.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ struct Term {
nonLinearConstraint @20 :TermId;
constFunc @22 :RegionId;
constAdt @23 :ConstAdt;
bytes @24 :Data;
bytesType @25 :Void;
}

struct Apply {
Expand Down
5 changes: 5 additions & 0 deletions hugr-model/src/v0/binary/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,11 @@ fn read_term<'a>(bump: &'a Bump, reader: hugr_capnp::term::Reader) -> ReadResult
let values = model::TermId(reader.get_values());
model::Term::ConstAdt { tag, values }
}

Which::Bytes(bytes) => model::Term::Bytes {
data: bump.alloc_slice_copy(bytes?),
},
Which::BytesType(()) => model::Term::BytesType,
})
}

Expand Down
8 changes: 8 additions & 0 deletions hugr-model/src/v0/binary/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,14 @@ fn write_term(mut builder: hugr_capnp::term::Builder, term: &model::Term) {
builder.set_tag(*tag);
builder.set_values(values.0);
}

model::Term::Bytes { data } => {
builder.set_bytes(data);
}

model::Term::BytesType => {
builder.set_bytes_type(());
}
}
}

Expand Down
9 changes: 9 additions & 0 deletions hugr-model/src/v0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,15 @@ pub enum Term<'a> {
/// The values of the variant.
values: TermId,
},

/// A literal byte string.
Bytes {
/// The data of the byte string.
data: &'a [u8],
},

/// The type of byte strings.
BytesType,
}

/// A part of a list term.
Expand Down
6 changes: 6 additions & 0 deletions hugr-model/src/v0/text/hugr.pest
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ string_raw = @{ (!("\\" | "\"") ~ ANY)+ }
string_escape = @{ "\\" ~ ("\"" | "\\" | "n" | "r" | "t") }
string_unicode = @{ "\\u" ~ "{" ~ ASCII_HEX_DIGIT+ ~ "}" }

// Base64 payload of a bytes literal, including the surrounding quotes.
// Atomic (`@`), like the string rules above, so that implicit whitespace or
// comment skipping is never applied inside the quoted literal; the parser
// hands the text between the quotes directly to the base64 decoder.
base64_string = @{ "\"" ~ (ASCII_ALPHANUMERIC | "+" | "/")* ~ "="* ~ "\"" }

module = { "(" ~ "hugr" ~ "0" ~ ")" ~ meta* ~ node* ~ EOI }

meta = { "(" ~ "meta" ~ symbol ~ term ~ ")" }
Expand Down Expand Up @@ -97,6 +99,8 @@ term = {
| term_non_linear
| term_const_func
| term_const_adt
| term_bytes_type
| term_bytes
}

term_wildcard = { "_" }
Expand All @@ -122,5 +126,7 @@ term_ctrl_type = { "ctrl" }
term_non_linear = { "(" ~ "nonlinear" ~ term ~ ")" }
term_const_func = { "(" ~ "fn" ~ term ~ ")" }
term_const_adt = { "(" ~ "tag" ~ tag ~ term* ~ ")" }
// The type of byte strings, written as the bare keyword `bytes`.
term_bytes_type = { "bytes" }
// A byte string literal: `(bytes "<base64>")`, where the payload is a quoted
// base64 string.
term_bytes = { "(" ~ "bytes" ~ base64_string ~ ")" }

spliced_term = { term ~ "..." }
15 changes: 15 additions & 0 deletions hugr-model/src/v0/text/parse.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use base64::{prelude::BASE64_STANDARD, Engine};
use bumpalo::{collections::String as BumpString, collections::Vec as BumpVec, Bump};
use fxhash::FxHashMap;
use pest::{
Expand Down Expand Up @@ -262,6 +263,20 @@ impl<'a> ParseContext<'a> {
Term::ConstAdt { tag, values }
}

Rule::term_bytes_type => Term::BytesType,

Rule::term_bytes => {
let token = inner.next().unwrap();
let slice = token.as_str();
// Remove the quotes
let slice = &slice[1..slice.len() - 1];
let data = BASE64_STANDARD.decode(slice).map_err(|_| {
ParseError::custom("invalid base64 encoding", token.as_span())
})?;
let data = self.bump.alloc_slice_copy(&data);
Term::Bytes { data }
}

r => unreachable!("term: {:?}", r),
};

Expand Down
20 changes: 20 additions & 0 deletions hugr-model/src/v0/text/print.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use base64::{prelude::BASE64_STANDARD, Engine};
use pretty::{Arena, DocAllocator, RefDoc};
use std::borrow::Cow;

Expand Down Expand Up @@ -598,6 +599,15 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> {
this.print_text(tag.to_string());
this.print_term(*values)
}),
Term::BytesType => {
self.print_text("bytes");
Ok(())
}
Term::Bytes { data } => self.print_parens(|this| {
this.print_text("bytes");
this.print_byte_string(data);
Ok(())
}),
}
}

Expand Down Expand Up @@ -717,4 +727,14 @@ impl<'p, 'a: 'p> PrintContext<'p, 'a> {
output.push('"');
self.print_text(output);
}

/// Emit `bytes` as a double-quoted base64 string using the
/// `BASE64_STANDARD` engine.
fn print_byte_string(&mut self, bytes: &[u8]) {
    // Each (possibly partial) group of 3 input bytes becomes 4 output
    // characters; the additional 2 make room for the surrounding quotes.
    let encoded_len = bytes.len().div_ceil(3) * 4;
    let mut quoted = String::with_capacity(encoded_len + 2);

    quoted.push('"');
    BASE64_STANDARD.encode_string(bytes, &mut quoted);
    quoted.push('"');

    self.print_text(quoted);
}
}
5 changes: 5 additions & 0 deletions hugr-model/tests/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,8 @@ pub fn test_lists() {
pub fn test_const() {
binary_roundtrip(include_str!("fixtures/model-const.edn"));
}

/// Runs `binary_roundtrip` on the literals fixture, which covers both the
/// string literal and the bytes literal.
#[test]
pub fn test_literals() {
    let fixture = include_str!("fixtures/model-literals.edn");
    binary_roundtrip(fixture);
}
1 change: 1 addition & 0 deletions hugr-model/tests/fixtures/model-literals.edn
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
(hugr 0)

(define-alias mod.string str "\"\n\r\t\\\u{1F44D}")
(define-alias mod.bytes bytes (bytes "SGVsbG8gd29ybGQg8J+Yig=="))
2 changes: 2 additions & 0 deletions hugr-model/tests/snapshots/text__literals.snap
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ expression: "roundtrip(include_str!(\"fixtures/model-literals.edn\"))"
(hugr 0)

(define-alias mod.string str "\"\n\r\t\\👍")

(define-alias mod.bytes bytes (bytes "SGVsbG8gd29ybGQg8J+Yig=="))

0 comments on commit 4b31d2b

Please sign in to comment.