Skip to content

Commit

Permalink
feat: add ArrayValue to python, rust and lowering (#1773)
Browse files Browse the repository at this point in the history
drive-by: fix Array type annotation

Closes #1771 

TODO:
 - [x] Rust ArrayValue
 - [x] LLVM lowering

---------

Co-authored-by: Mark Koch <[email protected]>
  • Loading branch information
ss2165 and mark-koch authored Dec 17, 2024
1 parent b310fac commit d429cff
Show file tree
Hide file tree
Showing 7 changed files with 288 additions and 14 deletions.
177 changes: 167 additions & 10 deletions hugr-core/src/std_extensions/collections/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,21 @@ mod array_scan;

use std::sync::Arc;

use itertools::Itertools as _;
use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::hash::{Hash, Hasher};

use crate::extension::resolution::{
resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError,
WeakExtensionRegistry,
};
use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp};
use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound};
use crate::ops::{ExtensionOp, OpName};
use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound};
use crate::ops::constant::{maybe_hash_values, CustomConst, TryHash, ValueName};
use crate::ops::{ExtensionOp, OpName, Value};
use crate::types::type_param::{TypeArg, TypeParam};
use crate::types::{Type, TypeBound, TypeName};
use crate::types::{CustomCheckFailure, CustomType, Type, TypeBound, TypeName};
use crate::Extension;

pub use array_op::{ArrayOp, ArrayOpDef, ArrayOpDefIter};
Expand All @@ -26,8 +34,128 @@ pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.ar
/// Extension version.
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Statically sized array of values, all of the same type.
pub struct ArrayValue {
values: Vec<Value>,
typ: Type,
}

impl ArrayValue {
/// Create a new [CustomConst] for an array of values of type `typ`.
/// That all values are of type `typ` is not checked here.
pub fn new(typ: Type, contents: impl IntoIterator<Item = Value>) -> Self {
Self {
values: contents.into_iter().collect_vec(),
typ,
}
}

/// Create a new [CustomConst] for an empty array of values of type `typ`.
pub fn new_empty(typ: Type) -> Self {
Self {
values: vec![],
typ,
}
}

/// Returns the type of the `[ArrayValue]` as a `[CustomType]`.`
pub fn custom_type(&self) -> CustomType {
array_custom_type(self.values.len() as u64, self.typ.clone())
}

/// Returns the type of values inside the `[ArrayValue]`.
pub fn get_element_type(&self) -> &Type {
&self.typ
}

/// Returns the values contained inside the `[ArrayValue]`.
pub fn get_contents(&self) -> &[Value] {
&self.values
}
}

impl TryHash for ArrayValue {
fn try_hash(&self, mut st: &mut dyn Hasher) -> bool {
maybe_hash_values(&self.values, &mut st) && {
self.typ.hash(&mut st);
true
}
}
}

#[typetag::serde]
impl CustomConst for ArrayValue {
fn name(&self) -> ValueName {
ValueName::new_inline("array")
}

fn get_type(&self) -> Type {
self.custom_type().into()
}

fn validate(&self) -> Result<(), CustomCheckFailure> {
let typ = self.custom_type();

EXTENSION
.get_type(&ARRAY_TYPENAME)
.unwrap()
.check_custom(&typ)
.map_err(|_| {
CustomCheckFailure::Message(format!(
"Custom typ {typ} is not a valid instantiation of array."
))
})?;

// constant can only hold classic type.
let ty = match typ.args() {
[TypeArg::BoundedNat { n }, TypeArg::Type { ty }]
if *n as usize == self.values.len() =>
{
ty
}
_ => {
return Err(CustomCheckFailure::Message(format!(
"Invalid array type arguments: {:?}",
typ.args()
)))
}
};

// check all values are instances of the element type
for v in &self.values {
if v.get_type() != *ty {
return Err(CustomCheckFailure::Message(format!(
"Array element {v:?} is not of expected type {ty}"
)));
}
}

Ok(())
}

fn equal_consts(&self, other: &dyn CustomConst) -> bool {
crate::ops::constant::downcast_equal_consts(self, other)
}

fn extension_reqs(&self) -> ExtensionSet {
ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs))
.union(EXTENSION_ID.into())
}

fn update_extensions(
&mut self,
extensions: &WeakExtensionRegistry,
) -> Result<(), ExtensionResolutionError> {
for val in &mut self.values {
resolve_value_extensions(val, extensions)?;
}
resolve_type_extensions(&mut self.typ, extensions)
}
}

lazy_static! {
/// Extension for list operations.
/// Extension for array operations.
pub static ref EXTENSION: Arc<Extension> = {
Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| {
extension.add_type(
Expand Down Expand Up @@ -55,7 +183,7 @@ fn array_type_def() -> &'static TypeDef {
/// This method is equivalent to [`array_type_parametric`], but uses concrete
/// arguments types to ensure no errors are possible.
pub fn array_type(size: u64, element_ty: Type) -> Type {
instantiate_array(array_type_def(), size, element_ty).expect("array parameters are valid")
array_custom_type(size, element_ty).into()
}

/// Instantiate a new array type given the size and element type parameters.
Expand All @@ -68,14 +196,25 @@ pub fn array_type_parametric(
instantiate_array(array_type_def(), size, element_ty)
}

fn array_custom_type(size: impl Into<TypeArg>, element_ty: impl Into<TypeArg>) -> CustomType {
instantiate_array_custom(array_type_def(), size, element_ty)
.expect("array parameters are valid")
}

fn instantiate_array_custom(
array_def: &TypeDef,
size: impl Into<TypeArg>,
element_ty: impl Into<TypeArg>,
) -> Result<CustomType, SignatureError> {
array_def.instantiate(vec![size.into(), element_ty.into()])
}

fn instantiate_array(
array_def: &TypeDef,
size: impl Into<TypeArg>,
element_ty: impl Into<TypeArg>,
) -> Result<Type, SignatureError> {
array_def
.instantiate(vec![size.into(), element_ty.into()])
.map(Into::into)
instantiate_array_custom(array_def, size, element_ty).map(Into::into)
}

/// Name of the operation in the prelude for creating new arrays.
Expand All @@ -90,9 +229,11 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp {
#[cfg(test)]
mod test {
use crate::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr};
use crate::extension::prelude::qb_t;
use crate::extension::prelude::{qb_t, usize_t, ConstUsize};
use crate::ops::constant::CustomConst;
use crate::std_extensions::arithmetic::float_types::ConstF64;

use super::{array_type, new_array_op};
use super::{array_type, new_array_op, ArrayValue};

#[test]
/// Test building a HUGR involving a new_array operation.
Expand All @@ -108,4 +249,20 @@ mod test {

b.finish_hugr_with_outputs(out.outputs()).unwrap();
}

#[test]
fn test_array_value() {
let array_value = ArrayValue {
values: vec![ConstUsize::new(3).into()],
typ: usize_t(),
};

array_value.validate().unwrap();

let wrong_array_value = ArrayValue {
values: vec![ConstF64::new(1.2).into()],
typ: usize_t(),
};
assert!(wrong_array_value.validate().is_err());
}
}
2 changes: 1 addition & 1 deletion hugr-core/src/std_extensions/collections/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub struct ListValue(Vec<Value>, Type);

impl ListValue {
/// Create a new [CustomConst] for a list of values of type `typ`.
/// That all values ore of type `typ` is not checked here.
/// That all values are of type `typ` is not checked here.
pub fn new(typ: Type, contents: impl IntoIterator<Item = Value>) -> Self {
Self(contents.into_iter().collect_vec(), typ)
}
Expand Down
56 changes: 56 additions & 0 deletions hugr-llvm/src/extension/collections/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use inkwell::values::{
use inkwell::IntPredicate;
use itertools::Itertools;

use crate::emit::emit_value;
use crate::{
emit::{deaggregate_call_result, EmitFuncContext, RowPromise},
sum::LLVMSumType,
Expand Down Expand Up @@ -52,6 +53,15 @@ pub trait ArrayCodegen: Clone {
elem_ty.array_type(size as u32)
}

/// Emit a [hugr_core::std_extensions::collections::array::ArrayValue].
fn emit_array_value<'c, H: HugrView>(
&self,
ctx: &mut EmitFuncContext<'c, '_, H>,
value: &array::ArrayValue,
) -> Result<BasicValueEnum<'c>> {
emit_array_value(self, ctx, value)
}

/// Emit a [hugr_core::std_extensions::collections::array::ArrayOp].
fn emit_array_op<'c, H: HugrView>(
&self,
Expand Down Expand Up @@ -129,6 +139,10 @@ impl<CCG: ArrayCodegen> CodegenExtension for ArrayCodegenExtension<CCG> {
Ok(ccg.array_type(&ts, elem_ty, *n).as_basic_type_enum())
}
})
.custom_const::<array::ArrayValue>({
let ccg = self.0.clone();
move |context, k| ccg.emit_array_value(context, k)
})
.simple_extension_op::<ArrayOpDef>({
let ccg = self.0.clone();
move |context, args, _| {
Expand Down Expand Up @@ -244,6 +258,31 @@ fn build_loop<'c, T, H: HugrView>(
Ok(val)
}

pub fn emit_array_value<'c, H: HugrView>(
ccg: &impl ArrayCodegen,
ctx: &mut EmitFuncContext<'c, '_, H>,
value: &array::ArrayValue,
) -> Result<BasicValueEnum<'c>> {
let ts = ctx.typing_session();
let llvm_array_ty = ccg
.array_type(
&ts,
ts.llvm_type(value.get_element_type())?,
value.get_contents().len() as u64,
)
.as_basic_type_enum()
.into_array_type();
let mut array_v = llvm_array_ty.get_undef();
for (i, v) in value.get_contents().iter().enumerate() {
let llvm_v = emit_value(ctx, v)?;
array_v = ctx
.builder()
.build_insert_value(array_v, llvm_v, i as u32, "")?
.into_array_value();
}
Ok(array_v.into())
}

pub fn emit_array_op<'c, H: HugrView>(
ccg: &impl ArrayCodegen,
ctx: &mut EmitFuncContext<'c, '_, H>,
Expand Down Expand Up @@ -739,6 +778,23 @@ mod test {
check_emission!(hugr, llvm_ctx);
}

#[rstest]
fn emit_array_value(mut llvm_ctx: TestContext) {
let hugr = SimpleHugrConfig::new()
.with_extensions(STD_REG.to_owned())
.with_outs(vec![array_type(2, usize_t())])
.finish(|mut builder| {
let vs = vec![ConstUsize::new(1).into(), ConstUsize::new(2).into()];
let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), vs));
builder.finish_with_outputs([arr]).unwrap()
});
llvm_ctx.add_extensions(|cge| {
cge.add_default_prelude_extensions()
.add_default_array_extensions()
});
check_emission!(hugr, llvm_ctx);
}

fn exec_registry() -> ExtensionRegistry {
ExtensionRegistry::new([
int_types::EXTENSION.to_owned(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
---
source: hugr-llvm/src/extension/collections/array.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define [2 x i64] @_hl.main.1() {
alloca_block:
br label %entry_block

entry_block: ; preds = %alloca_block
ret [2 x i64] [i64 1, i64 2]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
source: hugr-llvm/src/extension/collections/array.rs
expression: mod_str
---
; ModuleID = 'test_context'
source_filename = "test_context"

define [2 x i64] @_hl.main.1() {
alloca_block:
%"0" = alloca [2 x i64], align 8
%"5_0" = alloca [2 x i64], align 8
br label %entry_block

entry_block: ; preds = %alloca_block
store [2 x i64] [i64 1, i64 2], [2 x i64]* %"5_0", align 4
%"5_01" = load [2 x i64], [2 x i64]* %"5_0", align 4
store [2 x i64] %"5_01", [2 x i64]* %"0", align 4
%"02" = load [2 x i64], [2 x i64]* %"0", align 4
ret [2 x i64] %"02"
}
Loading

0 comments on commit d429cff

Please sign in to comment.