Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Builder and HugrMut add_op_xxx default to open extensions #622

Merged
merged 13 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/algorithm/nest_cfgs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ pub(crate) mod test {
])
);
transform_cfg_to_nested(&mut IdentityCfgMap::new(rc));
h.validate(&PRELUDE_REGISTRY).unwrap();
h.update_validate(&PRELUDE_REGISTRY).unwrap();
assert_eq!(1, depth(&h, entry));
assert_eq!(1, depth(&h, exit));
for n in [split, left, right, merge, head, tail] {
Expand Down Expand Up @@ -753,7 +753,7 @@ pub(crate) mod test {
let root = h.root();
let m = SiblingMut::<CfgID>::try_new(&mut h, root).unwrap();
transform_cfg_to_nested(&mut IdentityCfgMap::new(m));
h.validate(&PRELUDE_REGISTRY).unwrap();
h.update_validate(&PRELUDE_REGISTRY).unwrap();
assert_eq!(1, depth(&h, entry));
assert_eq!(3, depth(&h, head));
for n in [split, left, right, merge] {
Expand Down
12 changes: 6 additions & 6 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,18 +148,18 @@ pub(crate) mod test {
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
signature: signature.clone(),
}));
hugr.add_node_with_parent(
hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Input {
ops::Input {
types: signature.input,
}),
},
)
.unwrap();
hugr.add_node_with_parent(
hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::Output {
ops::Output {
types: signature.output,
}),
},
)
.unwrap();
hugr
Expand Down
5 changes: 2 additions & 3 deletions src/builder/conditional.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,9 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> ConditionalBuilder<B> {
let case_node =
// add case before any existing subsequent cases
if let Some(&sibling_node) = self.case_nodes[case + 1..].iter().flatten().next() {
// TODO: Allow this to be non-pure
self.hugr_mut().add_node_before(sibling_node, NodeType::open_extensions(case_op))?
self.hugr_mut().add_op_before(sibling_node, case_op)?
} else {
self.add_child_node(NodeType::open_extensions(case_op))?
self.add_child_op(case_op)?
};

self.case_nodes[case] = Some(case_node);
Expand Down
156 changes: 75 additions & 81 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,12 @@ impl UnificationContext {
m_output,
node_type.op_signature().extension_reqs,
);
if matches!(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tried a few things here - see e.g. commit 59eb260 and this comment. Think this is better than that...

but a better-still way (???) might be to define OpType::default_extensions() -> Option<ExtensionSet> and then make add_op_xxxx use not open_extensions but default. (Where default_extensions is None i.e. open for most OpTypes, but ExtensionSet::new() i.e. pure for Alias/Function/FuncDefn.) Thoughts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Could be good for a later refactor - at the same time as sorting out NodeType::pure etc. as that should be new_pure or OpType::pure(self), like FunctionType::pure(self) -> Signature)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good ideas, both imo! The first one, making "open_extensions" a more well supported "default" option was hampered previously by the lack of extension inference so now we can reconsider

node_type.tag(),
OpTag::Alias | OpTag::Function | OpTag::FuncDefn
) {
self.add_solution(m_input, ExtensionSet::new());
}
}
// We have a solution for everything!
Some(sig) => {
Expand All @@ -338,16 +344,16 @@ impl UnificationContext {
| Some(EdgeKind::ControlFlow)
)
}) {
let m_tgt = *self
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a driveby lifting of a loop invariant, nothing more

.extensions
.get(&(tgt_node, Direction::Incoming))
.unwrap();
for (src_node, _) in hugr.linked_ports(tgt_node, port) {
let m_src = self
.extensions
.get(&(src_node, Direction::Outgoing))
.unwrap();
let m_tgt = self
.extensions
.get(&(tgt_node, Direction::Incoming))
.unwrap();
self.add_constraint(*m_src, Constraint::Equal(*m_tgt));
self.add_constraint(*m_src, Constraint::Equal(m_tgt));
}
}
}
Expand Down Expand Up @@ -727,11 +733,11 @@ mod test {
let root_node = NodeType::open_extensions(op);
let mut hugr = Hugr::new(root_node);

let input = NodeType::open_extensions(ops::Input::new(type_row![NAT, NAT]));
let output = NodeType::open_extensions(ops::Output::new(type_row![NAT]));
let input = ops::Input::new(type_row![NAT, NAT]);
let output = ops::Output::new(type_row![NAT]);

let input = hugr.add_node_with_parent(hugr.root(), input)?;
let output = hugr.add_node_with_parent(hugr.root(), output)?;
let input = hugr.add_op_with_parent(hugr.root(), input)?;
let output = hugr.add_op_with_parent(hugr.root(), output)?;

assert_matches!(hugr.get_io(hugr.root()), Some(_));

Expand All @@ -747,29 +753,29 @@ mod test {
let mult_c_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT])
.with_extension_delta(&ExtensionSet::singleton(&C));

let add_a = hugr.add_node_with_parent(
let add_a = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_a_sig,
}),
},
)?;
let add_b = hugr.add_node_with_parent(
let add_b = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_b_sig,
}),
},
)?;
let add_ab = hugr.add_node_with_parent(
let add_ab = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_ab_sig,
}),
},
)?;
let mult_c = hugr.add_node_with_parent(
let mult_c = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: mult_c_sig,
}),
},
)?;

hugr.connect(input, 0, add_a, 0)?;
Expand Down Expand Up @@ -903,29 +909,26 @@ mod test {
let [input, output] = hugr.get_io(hugr.root()).unwrap();
let add_r_sig = FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs);

let add_r = hugr.add_node_with_parent(
let add_r = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: add_r_sig,
}),
},
)?;

// Dangling thingy
let src_sig = FunctionType::new(type_row![], type_row![NAT])
.with_extension_delta(&ExtensionSet::new());

let src = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG { signature: src_sig }),
)?;
let src = hugr.add_op_with_parent(hugr.root(), ops::DFG { signature: src_sig })?;

let mult_sig = FunctionType::new(type_row![NAT, NAT], type_row![NAT]);
// Mult has open extension requirements, which we should solve to be "R"
let mult = hugr.add_node_with_parent(
let mult = hugr.add_op_with_parent(
hugr.root(),
NodeType::open_extensions(ops::DFG {
ops::DFG {
signature: mult_sig,
}),
},
)?;

hugr.connect(input, 0, add_r, 0)?;
Expand Down Expand Up @@ -985,18 +988,18 @@ mod test {
) -> Result<[Node; 3], Box<dyn Error>> {
let op: OpType = op.into();

let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?;
let input = hugr.add_node_with_parent(
let node = hugr.add_op_with_parent(parent, op)?;
let input = hugr.add_op_with_parent(
node,
NodeType::open_extensions(ops::Input {
ops::Input {
types: op_sig.input,
}),
},
)?;
let output = hugr.add_node_with_parent(
let output = hugr.add_op_with_parent(
node,
NodeType::open_extensions(ops::Output {
ops::Output {
types: op_sig.output,
}),
},
)?;
Ok([node, input, output])
}
Expand All @@ -1017,20 +1020,20 @@ mod test {
Into::<OpType>::into(op).signature(),
)?;

let lift1 = hugr.add_node_with_parent(
let lift1 = hugr.add_op_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: first_ext,
}),
},
)?;

let lift2 = hugr.add_node_with_parent(
let lift2 = hugr.add_op_with_parent(
case,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: second_ext,
}),
},
)?;

hugr.connect(case_in, 0, lift1, 0)?;
Expand Down Expand Up @@ -1095,17 +1098,17 @@ mod test {
}));

let root = hugr.root();
let input = hugr.add_node_with_parent(
let input = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::Input {
ops::Input {
types: type_row![NAT],
}),
},
)?;
let output = hugr.add_node_with_parent(
let output = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::Output {
ops::Output {
types: type_row![NAT],
}),
},
)?;

// Make identical dataflow nodes which add extension requirement "A" or "B"
Expand All @@ -1126,12 +1129,12 @@ mod test {
.unwrap();

let lift = hugr
.add_node_with_parent(
.add_op_with_parent(
node,
NodeType::open_extensions(ops::LeafOp::Lift {
ops::LeafOp::Lift {
type_row: type_row![NAT],
new_extension: ext,
}),
},
)
.unwrap();

Expand Down Expand Up @@ -1178,7 +1181,7 @@ mod test {

let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?;

let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?;
let dfg = hugr.add_op_with_parent(bb, op)?;

hugr.connect(bb_in, 0, dfg, 0)?;
hugr.connect(dfg, 0, bb_out, 0)?;
Expand Down Expand Up @@ -1210,23 +1213,20 @@ mod test {
extension_delta: entry_extensions,
};

let exit = hugr.add_node_with_parent(
let exit = hugr.add_op_with_parent(
root,
NodeType::open_extensions(ops::BasicBlock::Exit {
ops::BasicBlock::Exit {
cfg_outputs: exit_types.into(),
}),
},
)?;

let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?;
let entry_in = hugr.add_node_with_parent(
let entry = hugr.add_op_before(exit, dfb)?;
let entry_in = hugr.add_op_with_parent(entry, ops::Input { types: inputs })?;
let entry_out = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(ops::Input { types: inputs }),
)?;
let entry_out = hugr.add_node_with_parent(
entry,
NodeType::open_extensions(ops::Output {
ops::Output {
types: vec![entry_tuple_sum].into(),
}),
},
)?;

Ok(([entry, entry_in, entry_out], exit))
Expand Down Expand Up @@ -1277,12 +1277,12 @@ mod test {
type_row![NAT],
)?;

let mkpred = hugr.add_node_with_parent(
let mkpred = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
make_opaque(
A,
FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a),
)),
),
)?;

// Internal wiring for DFGs
Expand Down Expand Up @@ -1373,12 +1373,9 @@ mod test {
type_row![NAT],
)?;

let entry_mid = hugr.add_node_with_parent(
let entry_mid = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], twoway(NAT)),
)),
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], twoway(NAT))),
)?;

hugr.connect(entry_in, 0, entry_mid, 0)?;
Expand Down Expand Up @@ -1462,12 +1459,12 @@ mod test {
type_row![NAT],
)?;

let entry_dfg = hugr.add_node_with_parent(
let entry_dfg = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext),
)),
),
)?;

hugr.connect(entry_in, 0, entry_dfg, 0)?;
Expand Down Expand Up @@ -1543,12 +1540,9 @@ mod test {
type_row![NAT],
)?;

let entry_mid = hugr.add_node_with_parent(
let entry_mid = hugr.add_op_with_parent(
entry,
NodeType::open_extensions(make_opaque(
UNKNOWN_EXTENSION,
FunctionType::new(vec![NAT], oneway(NAT)),
)),
make_opaque(UNKNOWN_EXTENSION, FunctionType::new(vec![NAT], oneway(NAT))),
)?;

hugr.connect(entry_in, 0, entry_mid, 0)?;
Expand Down
Loading