From 4c1c6ee4c7d657c4bdb6b37c2237ae3f06b8d0be Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Mon, 9 Dec 2024 10:57:56 +0000 Subject: [PATCH] feat!: Make array repeat and scan ops generic over extension reqs (#1716) Closes #1714 BREAKING CHANGE: Array `scan` and `repeat` ops get an additional type parameter specifying the extension requirements of their input functions. Furthermore, `repeat` is no longer part of `ArrayOpDef` but is instead specified via a new `ArrayScan` struct. --- hugr-core/src/extension/prelude.rs | 1 + hugr-core/src/extension/prelude/array.rs | 270 ++++++++++++++++--- hugr-py/src/hugr/std/_json_defs/prelude.json | 14 +- specification/std_extensions/prelude.json | 14 +- 4 files changed, 261 insertions(+), 38 deletions(-) diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index bb6708a86..786b0379e 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -122,6 +122,7 @@ lazy_static! { NoopDef.add_to_extension(prelude, extension_ref).unwrap(); LiftDef.add_to_extension(prelude, extension_ref).unwrap(); array::ArrayOpDef::load_all_ops(prelude, extension_ref).unwrap(); + array::ArrayRepeatDef.add_to_extension(prelude, extension_ref).unwrap(); array::ArrayScanDef.add_to_extension(prelude, extension_ref).unwrap(); }) }; diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index a26e2071b..3b8512fef 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -12,6 +12,7 @@ use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; use crate::extension::ExtensionId; +use crate::extension::ExtensionSet; use crate::extension::OpDef; use crate::extension::SignatureFromArgs; use crate::extension::SignatureFunc; @@ -24,6 +25,7 @@ use crate::types::FuncTypeBase; use crate::types::FuncValueType; use crate::types::RowVariable; +use crate::types::Signature; use crate::types::TypeBound; use crate::types::Type; @@ -52,7 +54,6 @@ pub enum ArrayOpDef { pop_left, pop_right, discard_empty, - repeat, } /// Static parameters for array operations. Includes array size. Type is part of the type scheme. @@ -135,14 +136,6 @@ impl ArrayOpDef { let usize_t: Type = usize_custom_t(extension_ref).into(); match self { - repeat => { - let func = - Type::new_function(FuncValueType::new(type_row![], elem_ty_var.clone())); - PolyFuncTypeRV::new( - standard_params, - FuncValueType::new(vec![func], array_ty.clone()), - ) - } get => { let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); @@ -208,10 +201,6 @@ impl MakeOpDef for ArrayOpDef { fn description(&self) -> String { match self { ArrayOpDef::new_array => "Create a new array from elements", - ArrayOpDef::repeat => { - "Creates a new array whose elements are initialised by calling \ - the given function n times" - } ArrayOpDef::get => "Get an element from an array", ArrayOpDef::set => "Set an element in an array", ArrayOpDef::swap => "Swap two elements in an array", @@ -281,7 +270,7 @@ impl MakeExtensionOp for ArrayOp { ); vec![ty_arg] } - new_array | repeat | pop_left | pop_right | get | set | swap => { + new_array | pop_left | pop_right | get | set | swap => { vec![TypeArg::BoundedNat { n: self.size }, ty_arg] } } @@ -347,6 +336,169 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp { op.to_extension_op().unwrap() } +/// Name of the operation to repeat a value multiple times +pub const ARRAY_REPEAT_OP_ID: OpName = OpName::new_inline("repeat"); + +/// Definition of the array repeat op. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub struct ArrayRepeatDef; + +impl NamedOp for ArrayRepeatDef { + fn name(&self) -> OpName { + ARRAY_REPEAT_OP_ID + } +} + +impl FromStr for ArrayRepeatDef { + type Err = (); + + fn from_str(s: &str) -> Result { + if s == ArrayRepeatDef.name() { + Ok(Self) + } else { + Err(()) + } + } +} + +impl ArrayRepeatDef { + /// To avoid recursion when defining the extension, take the type definition as an argument. + fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { + let params = vec![ + TypeParam::max_nat(), + TypeBound::Any.into(), + TypeParam::Extensions, + ]; + let n = TypeArg::new_var_use(0, TypeParam::max_nat()); + let t = Type::new_var_use(1, TypeBound::Any); + let es = ExtensionSet::type_var(2); + let func = + Type::new_function(Signature::new(vec![], vec![t.clone()]).with_extension_delta(es)); + let array_ty = instantiate(array_def, n, t); + PolyFuncTypeRV::new(params, FuncValueType::new(vec![func], array_ty)).into() + } +} + +impl MakeOpDef for ArrayRepeatDef { + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized, + { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { + self.signature_from_def(array_type_def()) + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) + } + + fn extension(&self) -> ExtensionId { + PRELUDE_ID + } + + fn description(&self) -> String { + "Creates a new array whose elements are initialised by calling \ + the given function n times" + .into() + } + + /// Add an operation implemented as a [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + // + // This method is re-defined here since we need to pass the array type def while + // computing the signature, to avoid recursive loops initializing the extension. + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), crate::extension::ExtensionBuildError> { + let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; + + self.post_opdef(def); + + Ok(()) + } +} + +/// Definition of the array repeat op. +#[derive(Clone, Debug, PartialEq)] +pub struct ArrayRepeat { + /// The element type of the resulting array. + pub elem_ty: Type, + /// Size of the array. + pub size: u64, + /// The extensions required by the function that generates the array elements. + pub extension_reqs: ExtensionSet, +} + +impl ArrayRepeat { + /// Creates a new array repeat op. + pub fn new(elem_ty: Type, size: u64, extension_reqs: ExtensionSet) -> Self { + ArrayRepeat { + elem_ty, + size, + extension_reqs, + } + } +} + +impl NamedOp for ArrayRepeat { + fn name(&self) -> OpName { + ARRAY_REPEAT_OP_ID + } +} + +impl MakeExtensionOp for ArrayRepeat { + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + let def = ArrayRepeatDef::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![ + TypeArg::BoundedNat { n: self.size }, + self.elem_ty.clone().into(), + TypeArg::Extensions { + es: self.extension_reqs.clone(), + }, + ] + } +} + +impl MakeRegisteredOp for ArrayRepeat { + fn extension_id(&self) -> ExtensionId { + PRELUDE_ID + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { + &PRELUDE_REGISTRY + } +} + +impl HasDef for ArrayRepeat { + type Def = ArrayRepeatDef; +} + +impl HasConcrete for ArrayRepeatDef { + type Concrete = ArrayRepeat; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + match type_args { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }, TypeArg::Extensions { es }] => { + Ok(ArrayRepeat::new(ty.clone(), *n, es.clone())) + } + _ => Err(SignatureError::InvalidTypeArgs.into()), + } + } +} + /// Name of the operation for the combined map/fold operation pub const ARRAY_SCAN_OP_ID: OpName = OpName::new_inline("scan"); @@ -382,20 +534,25 @@ impl ArrayScanDef { TypeBound::Any.into(), TypeBound::Any.into(), TypeParam::new_list(TypeBound::Any), + TypeParam::Extensions, ]; let n = TypeArg::new_var_use(0, TypeParam::max_nat()); let t1 = Type::new_var_use(1, TypeBound::Any); let t2 = Type::new_var_use(2, TypeBound::Any); let s = TypeRV::new_row_var_use(3, TypeBound::Any); + let es = ExtensionSet::type_var(4); PolyFuncTypeRV::new( params, FuncTypeBase::::new( vec![ instantiate(array_def, n.clone(), t1.clone()).into(), - Type::new_function(FuncTypeBase::::new( - vec![t1.into(), s.clone()], - vec![t2.clone().into(), s.clone()], - )) + Type::new_function( + FuncTypeBase::::new( + vec![t1.into(), s.clone()], + vec![t2.clone().into(), s.clone()], + ) + .with_extension_delta(es), + ) .into(), s.clone(), ], @@ -457,22 +614,32 @@ impl MakeOpDef for ArrayScanDef { #[derive(Clone, Debug, PartialEq)] pub struct ArrayScan { /// The element type of the input array. - src_ty: Type, + pub src_ty: Type, /// The target element type of the output array. - tgt_ty: Type, + pub tgt_ty: Type, /// The accumulator types. - acc_tys: Vec, + pub acc_tys: Vec, /// Size of the array. - size: u64, + pub size: u64, + /// The extensions required by the scan function. + pub extension_reqs: ExtensionSet, } impl ArrayScan { - fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec, size: u64) -> Self { + /// Creates a new array scan op. + pub fn new( + src_ty: Type, + tgt_ty: Type, + acc_tys: Vec, + size: u64, + extension_reqs: ExtensionSet, + ) -> Self { ArrayScan { src_ty, tgt_ty, acc_tys, size, + extension_reqs, } } } @@ -500,6 +667,9 @@ impl MakeExtensionOp for ArrayScan { TypeArg::Sequence { elems: self.acc_tys.clone().into_iter().map_into().collect(), }, + TypeArg::Extensions { + es: self.extension_reqs.clone(), + }, ] } } @@ -523,7 +693,7 @@ impl HasConcrete for ArrayScanDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }] => + [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }, TypeArg::Extensions { es }] => { let acc_tys: Result<_, OpLoadError> = acc_tys .iter() @@ -532,7 +702,13 @@ impl HasConcrete for ArrayScanDef { _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); - Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n)) + Ok(ArrayScan::new( + src_ty.clone(), + tgt_ty.clone(), + acc_tys?, + *n, + es.clone(), + )) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -693,11 +869,20 @@ mod tests { ); } + #[test] + fn test_repeat_def() { + let op = ArrayRepeat::new(qb_t(), 2, ExtensionSet::singleton(&PRELUDE_ID)); + let optype: OpType = op.clone().into(); + let new_op: ArrayRepeat = optype.cast().unwrap(); + assert_eq!(new_op, op); + } + #[test] fn test_repeat() { let size = 2; let element_ty = qb_t(); - let op = ArrayOpDef::repeat.to_concrete(element_ty.clone(), size); + let es = ExtensionSet::singleton(&PRELUDE_ID); + let op = ArrayRepeat::new(element_ty.clone(), size, es.clone()); let optype: OpType = op.into(); @@ -706,7 +891,10 @@ mod tests { assert_eq!( sig.io(), ( - &vec![Type::new_function(Signature::new(vec![], vec![qb_t()]))].into(), + &vec![Type::new_function( + Signature::new(vec![], vec![qb_t()]).with_extension_delta(es) + )] + .into(), &vec![array_type(size, element_ty.clone())].into(), ) ); @@ -714,7 +902,13 @@ mod tests { #[test] fn test_scan_def() { - let op = ArrayScan::new(bool_t(), qb_t(), vec![usize_t()], 2); + let op = ArrayScan::new( + bool_t(), + qb_t(), + vec![usize_t()], + 2, + ExtensionSet::singleton(&PRELUDE_ID), + ); let optype: OpType = op.clone().into(); let new_op: ArrayScan = optype.cast().unwrap(); assert_eq!(new_op, op); @@ -725,8 +919,9 @@ mod tests { let size = 2; let src_ty = qb_t(); let tgt_ty = bool_t(); + let es = ExtensionSet::singleton(&PRELUDE_ID); - let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size); + let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size, es.clone()); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -735,7 +930,9 @@ mod tests { ( &vec![ array_type(size, src_ty.clone()), - Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()])) + Type::new_function( + Signature::new(vec![src_ty], vec![tgt_ty.clone()]).with_extension_delta(es) + ) ] .into(), &vec![array_type(size, tgt_ty)].into(), @@ -750,12 +947,14 @@ mod tests { let tgt_ty = bool_t(); let acc_ty1 = usize_t(); let acc_ty2 = qb_t(); + let es = ExtensionSet::singleton(&PRELUDE_ID); let op = ArrayScan::new( src_ty.clone(), tgt_ty.clone(), vec![acc_ty1.clone(), acc_ty2.clone()], size, + es.clone(), ); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -765,10 +964,13 @@ mod tests { ( &vec![ array_type(size, src_ty.clone()), - Type::new_function(Signature::new( - vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], - vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] - )), + Type::new_function( + Signature::new( + vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], + vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] + ) + .with_extension_delta(es) + ), acc_ty1.clone(), acc_ty2.clone() ] diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index b48692b39..bc482976b 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -431,6 +431,9 @@ { "tp": "Type", "b": "A" + }, + { + "tp": "Extensions" } ], "body": { @@ -445,7 +448,9 @@ "b": "A" } ], - "extension_reqs": [] + "extension_reqs": [ + "2" + ] } ], "output": [ @@ -503,6 +508,9 @@ "tp": "Type", "b": "A" } + }, + { + "tp": "Extensions" } ], "body": { @@ -557,7 +565,9 @@ "b": "A" } ], - "extension_reqs": [] + "extension_reqs": [ + "4" + ] }, { "t": "R", diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index b48692b39..bc482976b 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -431,6 +431,9 @@ { "tp": "Type", "b": "A" + }, + { + "tp": "Extensions" } ], "body": { @@ -445,7 +448,9 @@ "b": "A" } ], - "extension_reqs": [] + "extension_reqs": [ + "2" + ] } ], "output": [ @@ -503,6 +508,9 @@ "tp": "Type", "b": "A" } + }, + { + "tp": "Extensions" } ], "body": { @@ -557,7 +565,9 @@ "b": "A" } ], - "extension_reqs": [] + "extension_reqs": [ + "4" + ] }, { "t": "R",