From 71b6ff05db3aeb410d182ab28866a3e963de0dd2 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 7 Sep 2023 14:33:31 +0100 Subject: [PATCH] feat: Add extension deltas to CFG ops (#503) --- src/builder/build_traits.rs | 2 ++ src/builder/cfg.rs | 11 ++++++++++- src/hugr/validate.rs | 3 +++ src/ops/controlflow.rs | 13 +++++++++++++ src/ops/validate.rs | 1 + 5 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index f0e624132..22c7267af 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -325,6 +325,8 @@ pub trait Dataflow: Container { NodeType::open_extensions(ops::CFG { inputs: inputs.clone(), outputs: output_types.clone(), + // TODO: Make this a parameter + extension_delta: ExtensionSet::new(), }), input_wires, )?; diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index d007f04e3..027c8cab5 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -6,7 +6,10 @@ use super::{ }; use crate::ops::{self, BasicBlock, OpType}; -use crate::{extension::ExtensionRegistry, types::FunctionType}; +use crate::{ + extension::{ExtensionRegistry, ExtensionSet}, + types::FunctionType, +}; use crate::{hugr::views::HugrView, types::TypeRow}; use crate::{ops::handle::NodeHandle, types::Type}; @@ -60,6 +63,8 @@ impl CFGBuilder { let cfg_op = ops::CFG { inputs: input.clone(), outputs: output.clone(), + // TODO: Make this a parameter + extension_delta: ExtensionSet::new(), }; // TODO: Allow input extensions to be specified @@ -130,6 +135,8 @@ impl + AsRef> CFGBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), predicate_variants: predicate_variants.clone(), + // TODO: Make this a parameter + extension_delta: ExtensionSet::new(), }); let parent = self.container_node(); let block_n = if entry { @@ -277,6 +284,8 @@ impl BlockBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), predicate_variants: predicate_variants.clone(), + // TODO: make this a parameter + extension_delta: ExtensionSet::new(), }; // TODO: Allow input extensions to be specified diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 5e4438fee..f45b7c5fc 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -953,6 +953,7 @@ mod test { NodeType::pure(ops::CFG { inputs: type_row![BOOL_T], outputs: type_row![BOOL_T], + extension_delta: ExtensionSet::new(), }), ); assert_matches!( @@ -969,6 +970,7 @@ mod test { inputs: type_row![BOOL_T], predicate_variants: vec![type_row![]], other_outputs: type_row![BOOL_T], + extension_delta: ExtensionSet::new(), }, ) .unwrap(); @@ -1009,6 +1011,7 @@ mod test { inputs: type_row![Q], predicate_variants: vec![type_row![]], other_outputs: type_row![Q], + extension_delta: ExtensionSet::new(), }), ); let mut block_children = b.hierarchy.children(block.index); diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 88e6ee435..74e684822 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -3,6 +3,7 @@ use smol_str::SmolStr; use crate::extension::ExtensionSet; +use crate::type_row; use crate::types::{EdgeKind, FunctionType, Type, TypeRow}; use super::dataflow::DataflowOpTrait; @@ -97,6 +98,7 @@ impl Conditional { pub struct CFG { pub inputs: TypeRow, pub outputs: TypeRow, + pub extension_delta: ExtensionSet, } impl_op_name!(CFG); @@ -110,6 +112,7 @@ impl DataflowOpTrait for CFG { fn signature(&self) -> FunctionType { FunctionType::new(self.inputs.clone(), self.outputs.clone()) + .with_extension_delta(&self.extension_delta) } } @@ -123,6 +126,7 @@ pub enum BasicBlock { inputs: TypeRow, other_outputs: TypeRow, predicate_variants: Vec, + extension_delta: ExtensionSet, }, /// The single exit node of the CFG, has no children, /// stores the types of the CFG node output. @@ -166,6 +170,15 @@ impl OpTrait for BasicBlock { fn other_output(&self) -> Option { Some(EdgeKind::ControlFlow) } + + fn signature(&self) -> FunctionType { + match self { + BasicBlock::DFB { + extension_delta, .. + } => FunctionType::new(type_row![], type_row![]).with_extension_delta(extension_delta), + BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]), + } + } } impl BasicBlock { diff --git a/src/ops/validate.rs b/src/ops/validate.rs index ebf092ecd..76c6cfd45 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -343,6 +343,7 @@ impl ValidateOp for BasicBlock { inputs, predicate_variants, other_outputs: outputs, + extension_delta: _, } => { let predicate_type = Type::new_predicate(predicate_variants.clone()); let node_outputs: TypeRow = [&[predicate_type], outputs.as_ref()].concat().into();