Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Insert/extract subgraphs from a HugrView #552

Merged
merged 5 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ pub trait Dataflow: Container {
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let node = self.add_hugr(hugr)?.new_root;
let node = self.add_hugr(hugr)?.new_root.unwrap();

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand All @@ -252,7 +252,7 @@ pub trait Dataflow: Container {
input_wires: impl IntoIterator<Item = Wire>,
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
let num_outputs = hugr.get_optype(hugr.root()).signature().output_count();
let node = self.add_hugr_view(hugr)?.new_root;
let node = self.add_hugr_view(hugr)?.new_root.unwrap();

let inputs = input_wires.into_iter().collect();
wire_up_inputs(inputs, node, self)?;
Expand Down
102 changes: 95 additions & 7 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::collections::HashMap;
use std::ops::Range;

use portgraph::view::{NodeFilter, NodeFiltered};
use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap};

use crate::hugr::{Direction, HugrError, HugrView, Node, NodeType};
Expand All @@ -12,6 +13,7 @@ use crate::{Hugr, Port};

use self::sealed::HugrMutInternals;

use super::views::SiblingSubgraph;
use super::{NodeMetadata, PortIndex, Rewrite};

/// Functions for low-level building of a HUGR.
Expand Down Expand Up @@ -158,6 +160,23 @@ pub trait HugrMut: HugrView + HugrMutInternals {
self.hugr_mut().insert_from_view(root, other)
}

/// Copy a subgraph from another hugr into this one, under a given root node.
///
/// Sibling order is not preserved.
//
// TODO: Try to preserve the order when possible? We cannot always ensure
// it, since the subgraph may have arbitrary nodes without including their
// parent.
fn insert_subgraph(
&mut self,
root: Node,
other: &impl HugrView,
subgraph: &SiblingSubgraph,
) -> Result<InsertionResult, HugrError> {
self.valid_node(root)?;
self.hugr_mut().insert_subgraph(root, other, subgraph)
}

/// Applies a rewrite to the graph.
fn apply_rewrite<R, E>(&mut self, rw: impl Rewrite<ApplyResult = R, Error = E>) -> Result<R, E>
where
Expand All @@ -171,15 +190,21 @@ pub trait HugrMut: HugrView + HugrMutInternals {
/// via [HugrMut::insert_hugr] or [HugrMut::insert_from_view]
pub struct InsertionResult {
/// The node, after insertion, that was the root of the inserted Hugr.
/// (That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root]))
pub new_root: Node,
///
/// That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root]
///
/// When inserting a subgraph, this value is `None`.
pub new_root: Option<Node>,
/// Map from nodes in the Hugr/view that was inserted, to their new
/// positions in the Hugr into which said was inserted.
pub node_map: HashMap<Node, Node>,
}

impl InsertionResult {
fn translating_indices(new_root: Node, node_map: HashMap<NodeIndex, NodeIndex>) -> Self {
fn translating_indices(
new_root: Option<Node>,
node_map: HashMap<NodeIndex, NodeIndex>,
) -> Self {
Self {
new_root,
node_map: HashMap::from_iter(node_map.into_iter().map(|(k, v)| (k.into(), v.into()))),
Expand Down Expand Up @@ -276,10 +301,13 @@ where
let optype = other.op_types.take(node);
self.as_mut().op_types.set(new_node, optype);
let meta = other.metadata.take(node);
self.as_mut().set_metadata(node.into(), meta).unwrap();
self.as_mut().set_metadata(new_node.into(), meta).unwrap();
}
debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index));
Ok(InsertionResult::translating_indices(other_root, node_map))
Ok(InsertionResult::translating_indices(
Some(other_root),
node_map,
))
}

fn insert_from_view(
Expand All @@ -294,11 +322,40 @@ where
self.as_mut().op_types.set(new_node, nodetype.clone());
let meta = other.get_metadata(node.into());
self.as_mut()
.set_metadata(node.into(), meta.clone())
.set_metadata(new_node.into(), meta.clone())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was done in an earlier PR, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh no, that was outline_cfg. but same error 😲

Copy link
Collaborator Author

@aborgna-q aborgna-q Sep 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, seems like an easy to miss bug :/
And we never do anything with the metadata, so the error is never triggered.

.unwrap();
}
debug_assert_eq!(Some(&other_root.index), node_map.get(&other.root().index));
Ok(InsertionResult::translating_indices(other_root, node_map))
Ok(InsertionResult::translating_indices(
Some(other_root),
node_map,
))
}

fn insert_subgraph(
&mut self,
root: Node,
other: &impl HugrView,
subgraph: &SiblingSubgraph,
) -> Result<InsertionResult, HugrError> {
// Create a portgraph view with the explicit list of nodes defined by the subgraph.
let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> =
NodeFiltered::new_node_filtered(
other.portgraph(),
|node, ctx| ctx.contains(&node.into()),
subgraph.nodes(),
);
let node_map = insert_hugr_internal_with_portgraph(self.as_mut(), root, other, &portgraph)?;
// Update the optypes and metadata, copying them from the other graph.
for (&node, &new_node) in node_map.iter() {
let nodetype = other.get_nodetype(node.into());
self.as_mut().op_types.set(new_node, nodetype.clone());
let meta = other.get_metadata(node.into());
self.as_mut()
.set_metadata(new_node.into(), meta.clone())
.unwrap();
}
Ok(InsertionResult::translating_indices(None, node_map))
}
}

Expand Down Expand Up @@ -341,6 +398,37 @@ fn insert_hugr_internal(
Ok((other_root.into(), node_map))
}

/// Internal implementation of the `insert_subgraph` method for AsMut<Hugr>.
///
/// Returns a mapping from the nodes in the inserted graph to their new indices
/// in `hugr`.
///
/// This function does not update the optypes of the inserted nodes, so the
/// caller must do that.
///
/// In contrast to `insert_hugr_internal`, this function does not preserve
/// sibling order in the hierarchy.
fn insert_hugr_internal_with_portgraph(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why this function is being defined instead of using the existing one? I assume in the existing function the graph insertion is done as a side effect, plus returning the inserted root node in the graph seems preferable to putting an Option there in the InsertionResult

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

insert_hugr_internal assumes that you are inserting a complete view. That is, a complete region with a single root node (and maybe more descendants). So we can re-connect the hierarchy of the inserted nodes simply by doing a BFS traversal from the old root.

When inserting a subgraph, we may potentially include nodes from any part of the hugr. We then preserve the hierarchy when both parent and children are being inserted, but if a node's parent is missing then we connect it to the indicated root.

Because of that, we cannot do the efficient traversal of insert_hugr_internal and instead have to iterate the inserted nodes and ask if their parent is also in the chosen set. That causes the sibling order to not be preserved (unless we add some costly pre-processing).

As for the other point, note the definition of the InsertionResult's root:

/// The node, after insertion, that was the root of the inserted Hugr.

For a subgraph there is no previous root.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see! Can you add some of this to docstring please? And maybe something like insert_subgraph would be a better name since I don't think we can really say "insert hugr" if the hugr we're inserting doesn't have a root

hugr: &mut Hugr,
root: Node,
other: &impl HugrView,
portgraph: &impl portgraph::LinkView,
) -> Result<HashMap<NodeIndex, NodeIndex>, HugrError> {
let node_map = hugr.graph.insert_graph(&portgraph)?;

// A map for nodes that we inserted before their parent, so we couldn't
// update the hierarchy with their new id.
for (&node, &new_node) in node_map.iter() {
let new_parent = other
.get_parent(node.into())
.and_then(|parent| node_map.get(&parent.index).copied())
.unwrap_or(root.index);
hugr.hierarchy.push_child(new_node, new_parent)?;
}

Ok(node_map)
}

pub(crate) mod sealed {
use super::*;

Expand Down
2 changes: 1 addition & 1 deletion src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl Rewrite for OutlineCfg {
.insert_hugr(outer_cfg, new_block_bldr.hugr().clone())
.unwrap();
(
ins_res.new_root,
ins_res.new_root.unwrap(),
*ins_res.node_map.get(&cfg.node()).unwrap(),
)
};
Expand Down
76 changes: 75 additions & 1 deletion src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ use itertools::Itertools;
use portgraph::{view::Subgraph, Direction, PortView};
use thiserror::Error;

use crate::builder::{Dataflow, DataflowHugr, FunctionBuilder};
use crate::extension::{ExtensionSet, PRELUDE_REGISTRY};
use crate::hugr::{HugrError, HugrMut};
use crate::types::Signature;
use crate::{
ops::{
handle::{ContainerHandle, DataflowOpID},
Expand Down Expand Up @@ -122,7 +126,7 @@ impl SiblingSubgraph {
/// ## Definition
///
/// More formally, the sibling subgraph of a graph $G = (V, E)$ given
/// by sets of incoming and outoing boundary edges $B_I, B_O \subseteq E$
/// by sets of incoming and outgoing boundary edges $B_I, B_O \subseteq E$
/// is the graph given by the connected components of the graph
/// $G' = (V, E \ B_I \ B_O)$ that contain at least one node that is either
/// - the target of an incoming boundary edge, or
Expand Down Expand Up @@ -281,6 +285,16 @@ impl SiblingSubgraph {
self.nodes.len()
}

/// Returns the computed [`IncomingPorts`] of the subgraph.
pub fn incoming_ports(&self) -> &IncomingPorts {
&self.inputs
}

/// Returns the computed [`OutgoingPorts`] of the subgraph.
pub fn outgoing_ports(&self) -> &OutgoingPorts {
&self.outputs
}

/// The signature of the subgraph.
pub fn signature(&self, hugr: &impl HugrView) -> FunctionType {
let input = self
Expand Down Expand Up @@ -386,6 +400,51 @@ impl SiblingSubgraph {
nu_out,
))
}

/// Create a new Hugr containing only the subgraph.
///
/// The new Hugr will contain a function root wth the same signature as the
/// subgraph and the specified `input_extensions`.
pub fn extract_subgraph(
&self,
hugr: &impl HugrView,
name: impl Into<String>,
input_extensions: ExtensionSet,
) -> Result<Hugr, HugrError> {
let signature = Signature {
signature: self.signature(hugr),
input_extensions,
};
let builder = FunctionBuilder::new(name, signature).unwrap();
let inputs = builder.input_wires();
let mut extracted = builder
.finish_hugr_with_outputs(inputs, &PRELUDE_REGISTRY)
.unwrap();
let node_map = extracted
.insert_subgraph(extracted.root(), hugr, self)?
.node_map;

// Disconnect the input and output nodes, and connect the inserted nodes
// in-between.
let [inp, out] = extracted.get_io(extracted.root()).unwrap();
for (inp_port, repl_ports) in extracted
.node_ports(inp, Direction::Outgoing)
.zip(self.inputs.iter())
{
extracted.disconnect(inp, inp_port)?;
for (repl_node, repl_port) in repl_ports {
extracted.connect(inp, inp_port, node_map[repl_node], *repl_port)?;
}
}
for (out_port, (repl_node, repl_port)) in extracted
.node_ports(out, Direction::Incoming)
.zip(self.outputs.iter())
{
extracted.connect(node_map[repl_node], *repl_port, out, out_port)?;
}

Ok(extracted)
}
}

/// Precompute convexity information for a HUGR.
Expand Down Expand Up @@ -590,6 +649,8 @@ pub enum InvalidSubgraph {

#[cfg(test)]
mod tests {
use std::error::Error;

use crate::{
builder::{
BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
Expand Down Expand Up @@ -821,4 +882,17 @@ mod tests {
};
assert_eq!(func_defn.signature, func.signature(&func_graph))
}

#[test]
fn extract_subgraph() -> Result<(), Box<dyn Error>> {
let (hugr, func_root) = build_hugr().unwrap();
let func_graph: SiblingGraph<'_, FuncID<true>> =
SiblingGraph::try_new(&hugr, func_root).unwrap();
let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap();
let extracted = subgraph.extract_subgraph(&hugr, "region", ExtensionSet::new())?;

extracted.validate(&PRELUDE_REGISTRY).unwrap();

Ok(())
}
}