Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Better subgraph verification errors #587

Merged
merged 2 commits into from
Oct 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 82 additions & 39 deletions src/hugr/views/sibling_subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use pyo3::{create_exception, exceptions::PyException, PyErr};
/// [`super::SiblingGraph`], not all nodes of the sibling graph must be
/// included. A convex subgraph is always an induced subgraph, i.e. it is defined
/// by a set of nodes and all edges between them.

///
/// The incoming boundary (resp. outgoing boundary) is given by the input (resp.
/// output) ports of the subgraph that are linked to nodes outside of the subgraph.
/// The signature of the subgraph is then given by the types of the incoming
Expand Down Expand Up @@ -485,6 +485,9 @@ fn validate_subgraph<H: HugrView>(
inputs: &IncomingPorts,
outputs: &OutgoingPorts,
) -> Result<(), InvalidSubgraph> {
// Copy of the nodes for fast lookup.
let node_set = nodes.iter().copied().collect::<HashSet<_>>();

// Check nodes is not empty
if nodes.is_empty() {
return Err(InvalidSubgraph::EmptySubgraph);
Expand All @@ -501,76 +504,84 @@ fn validate_subgraph<H: HugrView>(
.chain(outputs)
.any(|&(n, p)| is_order_edge(hugr, n, p))
{
unimplemented!("Linked other ports not supported at boundary")
unimplemented!("Connected order edges not supported at the boundary")
}

// Check inputs are incoming ports and outputs are outgoing ports
if inputs
if let Some(&(n, p)) = inputs
.iter()
.flatten()
.any(|(_, p)| p.direction() == Direction::Outgoing)
.find(|(_, p)| p.direction() == Direction::Outgoing)
{
return Err(InvalidSubgraph::InvalidBoundary);
}
if outputs
Err(InvalidSubgraphBoundary::InputPortDirection(n, p))?;
};
if let Some(&(n, p)) = outputs
.iter()
.any(|(_, p)| p.direction() == Direction::Incoming)
.find(|(_, p)| p.direction() == Direction::Incoming)
{
return Err(InvalidSubgraph::InvalidBoundary);
}
Err(InvalidSubgraphBoundary::OutputPortDirection(n, p))?;
};

let mut ports_inside = inputs.iter().flatten().chain(outputs).copied();
// Check incoming & outgoing ports have target resp. source inside
let nodes = nodes.iter().copied().collect::<HashSet<_>>();
if ports_inside.any(|(n, _)| !nodes.contains(&n)) {
return Err(InvalidSubgraph::InvalidBoundary);
}
let boundary_ports = inputs
.iter()
.flatten()
.chain(outputs)
.copied()
.collect_vec();
// Check that the boundary ports are all in the subgraph.
if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) {
Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?;
};
// Check that every inside port has at least one linked port outside.
if ports_inside.any(|(n, p)| hugr.linked_ports(n, p).all(|(n1, _)| nodes.contains(&n1))) {
return Err(InvalidSubgraph::InvalidBoundary);
}
if let Some(&(n, p)) = boundary_ports.iter().find(|&&(n, p)| {
hugr.linked_ports(n, p)
.all(|(n1, _)| node_set.contains(&n1))
}) {
Err(InvalidSubgraphBoundary::DisconnectedBoundaryPort(n, p))?;
};

// Check that every incoming port of a node in the subgraph whose source is not in the subgraph
// belongs to inputs.
if nodes.clone().into_iter().any(|n| {
if nodes.iter().any(|&n| {
hugr.node_inputs(n).any(|p| {
hugr.linked_ports(n, p).any(|(n1, _)| {
!nodes.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p)))
!node_set.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p)))
})
})
}) {
return Err(InvalidSubgraph::NotConvex);
}
// Check that every outgoing port of a node in the subgraph whose target is not in the subgraph
// belongs to outputs.
if nodes.clone().into_iter().any(|n| {
if nodes.iter().any(|&n| {
hugr.node_outputs(n).any(|p| {
hugr.linked_ports(n, p)
.any(|(n1, _)| !nodes.contains(&n1) && !outputs.contains(&(n, p)))
.any(|(n1, _)| !node_set.contains(&n1) && !outputs.contains(&(n, p)))
})
}) {
return Err(InvalidSubgraph::NotConvex);
}

// Check inputs are unique
if !inputs.iter().flatten().all_unique() {
return Err(InvalidSubgraph::InvalidBoundary);
return Err(InvalidSubgraphBoundary::NonUniqueInput.into());
}

// Check no incoming partition is empty
if inputs.iter().any(|p| p.is_empty()) {
return Err(InvalidSubgraph::InvalidBoundary);
return Err(InvalidSubgraphBoundary::EmptyPartition.into());
}

// Check edge types are equal within partition and copyable if partition size > 1
if !inputs.iter().all(|ports| {
if let Some((i, _)) = inputs.iter().enumerate().find(|(_, ports)| {
let Some(edge_t) = get_edge_type(hugr, ports) else {
return false;
return true;
};
let require_copy = ports.len() > 1;
!require_copy || edge_t.copyable()
require_copy && !edge_t.copyable()
}) {
return Err(InvalidSubgraph::InvalidBoundary);
}
Err(InvalidSubgraphBoundary::MismatchedTypes(i))?;
};

Ok(())
}
Expand Down Expand Up @@ -663,13 +674,41 @@ pub enum InvalidSubgraph {
EmptySubgraph,
/// An invalid boundary port was found.
#[error("Invalid boundary port.")]
InvalidBoundary,
InvalidBoundary(#[from] InvalidSubgraphBoundary),
}

/// Errors that can occur while constructing a [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum InvalidSubgraphBoundary {
/// A node in the input boundary is not Incoming.
#[error("Expected (node {0:?}, port {1:?}) in the input boundary to be an incoming port.")]
InputPortDirection(Node, Port),
/// A node in the output boundary is not Outgoing.
#[error("Expected (node {0:?}, port {1:?}) in the input boundary to be an outgoing port.")]
OutputPortDirection(Node, Port),
/// A boundary port's node is not in the set of nodes.
#[error("(node {0:?}, port {1:?}) is in the boundary, but node {0:?} is not in the set.")]
PortNodeNotInSet(Node, Port),
/// A boundary port has no connections outside the subgraph.
#[error("(node {0:?}, port {1:?}) is in the boundary, but the port is not connected to a node outside the subgraph.")]
DisconnectedBoundaryPort(Node, Port),
/// There's a non-unique input-boundary port.
#[error("A port in the input boundary is used multiple times.")]
NonUniqueInput,
/// There's an empty partition in the input boundary.
#[error("A partition in the input boundary is empty.")]
EmptyPartition,
/// Different types in a partition of the input boundary.
#[error("The partition {0} in the input boundary has ports with different types.")]
MismatchedTypes(usize),
}

#[cfg(test)]
mod tests {
use std::error::Error;

use cool_asserts::assert_matches;

use crate::extension::PRELUDE_REGISTRY;
use crate::{
builder::{
Expand Down Expand Up @@ -883,14 +922,16 @@ mod tests {
let (inp, _) = hugr.children(func_root).take(2).collect_tuple().unwrap();
let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
// All graph but one edge
assert!(matches!(
assert_matches!(
SiblingSubgraph::try_new(
vec![hugr.linked_ports(inp, first_cx_edge).collect()],
vec![(inp, first_cx_edge)],
&func,
),
Err(InvalidSubgraph::NotConvex)
));
Err(InvalidSubgraph::InvalidBoundary(
InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
))
);
}

#[test]
Expand All @@ -905,14 +946,14 @@ mod tests {
let not1_out = hugr.node_outputs(not1).next().unwrap();
let not3_inp = hugr.node_inputs(not3).next().unwrap();
let not3_out = hugr.node_outputs(not3).next().unwrap();
assert!(matches!(
assert_matches!(
SiblingSubgraph::try_new(
vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]],
vec![(not1, not1_out), (not3, not3_out)],
&func
),
Err(InvalidSubgraph::NotConvex)
));
);
}

#[test]
Expand All @@ -923,14 +964,16 @@ mod tests {
let cx_edges_in = hugr.node_outputs(inp);
let cx_edges_out = hugr.node_inputs(out);
// All graph but the CX
assert!(matches!(
assert_matches!(
SiblingSubgraph::try_new(
cx_edges_out.map(|p| vec![(out, p)]).collect(),
cx_edges_in.map(|p| (inp, p)).collect(),
&func,
),
Err(InvalidSubgraph::InvalidBoundary)
));
Err(InvalidSubgraph::InvalidBoundary(
InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
))
);
}

#[test]
Expand Down