Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Extension inference for conditional nodes #465

Merged
merged 7 commits into from
Sep 1, 2023
152 changes: 137 additions & 15 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
//! will succeed regardless of what the variable is instantiated to.

use super::{ExtensionId, ExtensionSet};
use crate::{hugr::views::HugrView, hugr::Node, ops::OpType, types::EdgeKind, Direction};
use crate::{
hugr::views::HugrView,
hugr::Node,
ops::{OpTag, OpTrait, OpType},
types::EdgeKind,
Direction,
};

use super::validate::ExtensionError;

Expand Down Expand Up @@ -286,6 +292,15 @@ impl UnificationContext {
}
}

if hugr.get_optype(node).tag() == OpTag::Conditional {
for case in hugr.children(node) {
let m_case_in = self.make_or_get_meta(case, Direction::Incoming);
let m_case_out = self.make_or_get_meta(case, Direction::Outgoing);
self.add_constraint(m_case_in, Constraint::Equal(m_input));
self.add_constraint(m_case_out, Constraint::Equal(m_output));
}
}

match node_type.signature() {
// Input extensions are open
None => {
Expand Down Expand Up @@ -654,22 +669,22 @@ mod test {
use crate::extension::{ExtensionSet, EMPTY_REG};
use crate::hugr::HugrMut;
use crate::hugr::{validate::ValidationError, Hugr, HugrView, NodeType};
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle};
use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait};
use crate::type_row;
use crate::types::{FunctionType, Type};

use cool_asserts::assert_matches;
use portgraph::NodeIndex;

const BIT: Type = crate::extension::prelude::USIZE_T;
const NAT: Type = crate::extension::prelude::USIZE_T;

#[test]
// Build up a graph with some holes in its extension requirements, and infer
// them.
fn from_graph() -> Result<(), Box<dyn Error>> {
let rs = ExtensionSet::from_iter(["A".into(), "B".into(), "C".into()]);
let main_sig =
FunctionType::new(type_row![BIT, BIT], type_row![BIT]).with_extension_delta(&rs);
FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_extension_delta(&rs);

let op = ops::DFG {
signature: main_sig,
Expand All @@ -678,24 +693,24 @@ mod test {
let root_node = NodeType::open_extensions(op);
let mut hugr = Hugr::new(root_node);

let input = NodeType::open_extensions(ops::Input::new(type_row![BIT, BIT]));
let output = NodeType::open_extensions(ops::Output::new(type_row![BIT]));
let input = NodeType::open_extensions(ops::Input::new(type_row![NAT, NAT]));
let output = NodeType::open_extensions(ops::Output::new(type_row![NAT]));

let input = hugr.add_node_with_parent(hugr.root(), input)?;
let output = hugr.add_node_with_parent(hugr.root(), output)?;

assert_matches!(hugr.get_io(hugr.root()), Some(_));

let add_a_sig = FunctionType::new(type_row![BIT], type_row![BIT])
let add_a_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&"A".into()));

let add_b_sig = FunctionType::new(type_row![BIT], type_row![BIT])
let add_b_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&"B".into()));

let add_ab_sig = FunctionType::new(type_row![BIT], type_row![BIT])
let add_ab_sig = FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::from_iter(["A".into(), "B".into()]));

let mult_c_sig = FunctionType::new(type_row![BIT, BIT], type_row![BIT])
let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&"C".into()));

let add_a = hugr.add_node_with_parent(
Expand Down Expand Up @@ -801,7 +816,7 @@ mod test {
// because of a missing lift node
fn missing_lift_node() -> Result<(), Box<dyn Error>> {
let builder = DFGBuilder::new(
FunctionType::new(type_row![BIT], type_row![BIT])
FunctionType::new(type_row![NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&"R".into())),
)?;
let [w] = builder.input_wires_arr();
Expand Down Expand Up @@ -848,11 +863,11 @@ mod test {
fn dangling_src() -> Result<(), Box<dyn Error>> {
let rs = ExtensionSet::singleton(&"R".into());
let root_signature =
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&rs);
FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs);
let mut builder = DFGBuilder::new(root_signature)?;
let [input_wire] = builder.input_wires_arr();

let add_r_sig = FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(&rs);
let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs);

let add_r = builder.add_dataflow_node(
NodeType::open_extensions(ops::DFG {
Expand All @@ -863,15 +878,15 @@ mod test {
let [wl] = add_r.outputs_arr();

// Dangling thingy
let src_sig = FunctionType::new(type_row![], type_row![BIT])
let src_sig = FunctionType::new(type_row![], type_row![NAT])
.with_extension_delta(&ExtensionSet::new());
let src = builder.add_dataflow_node(
NodeType::open_extensions(ops::DFG { signature: src_sig }),
[],
)?;
let [wr] = src.outputs_arr();

let mult_sig = FunctionType::new(type_row![BIT, BIT], type_row![BIT])
let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::new());
// Mult has open extension requirements, which we should solve to be "R"
let mult = builder.add_dataflow_node(
Expand Down Expand Up @@ -993,4 +1008,111 @@ mod test {

Ok(())
}

fn create_with_io(
hugr: &mut Hugr,
parent: Node,
op: impl Into<OpType>,
) -> Result<[Node; 3], Box<dyn Error>> {
let op: OpType = op.into();
let input_types = op.signature().input;
let output_types = op.signature().output;

let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?;
let input = hugr.add_node_with_parent(
node,
NodeType::open_extensions(ops::Input { types: input_types }),
)?;
let output = hugr.add_node_with_parent(
node,
NodeType::open_extensions(ops::Output {
types: output_types,
}),
)?;
Ok([node, input, output])
}

#[test]
fn test_conditional_inference() -> Result<(), Box<dyn Error>> {
fn build_case(
hugr: &mut Hugr,
conditional_node: Node,
op: ops::Case,
first_ext: ExtensionId,
second_ext: ExtensionId,
) -> Result<Node, Box<dyn Error>> {
let [case, case_in, case_out] = create_with_io(hugr, conditional_node, op)?;

let lift1 = hugr.add_node_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: first_ext,
}),
)?;

let lift2 = hugr.add_node_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: second_ext,
}),
)?;

hugr.connect(case_in, 0, lift1, 0)?;
hugr.connect(lift1, 0, lift2, 0)?;
hugr.connect(lift2, 0, case_out, 0)?;

Ok(case)
}

let predicate_inputs = vec![type_row![]; 2];
let rs = ExtensionSet::from_iter(["A".into(), "B".into()]);

let inputs = type_row![NAT];
let outputs = type_row![NAT];

let op = ops::Conditional {
predicate_inputs,
other_inputs: inputs.clone(),
outputs: outputs.clone(),
extension_delta: rs.clone(),
};

let mut hugr = Hugr::new(NodeType::pure(op));
let conditional_node = hugr.root();

let case_op = ops::Case {
signature: FunctionType::new(inputs, outputs).with_extension_delta(&rs),
};
let case0_node = build_case(
&mut hugr,
conditional_node,
case_op.clone(),
"A".into(),
"B".into(),
)?;

let case1_node = build_case(&mut hugr, conditional_node, case_op, "B".into(), "A".into())?;

hugr.infer_extensions()?;

for node in [case0_node, case1_node, conditional_node] {
assert_eq!(
hugr.get_nodetype(node)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
);
assert_eq!(
hugr.get_nodetype(node)
.signature()
.unwrap()
.input_extensions,
ExtensionSet::new()
);
}
Ok(())
}
}