Skip to content

Commit

Permalink
feat: PortIndex trait for undirected port parameters (#553)
Browse files Browse the repository at this point in the history
Add a `PortIndex` trait implemented by both `Port` and `usize`, to use
in all the methods that expect an undirected port offset (instead of
asking for an `usize`).
  • Loading branch information
aborgna-q committed Sep 22, 2023
1 parent dbb94d3 commit ffa2f0c
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 49 deletions.
16 changes: 6 additions & 10 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{Node, NodeMetadata, Port, ValidationError};
use crate::hugr::{Node, NodeMetadata, Port, PortIndex, ValidationError};
use crate::ops::{self, LeafOp, OpTrait, OpType};

use std::iter;
Expand Down Expand Up @@ -649,13 +649,7 @@ fn wire_up_inputs<T: Dataflow + ?Sized>(
data_builder: &mut T,
) -> Result<(), BuildError> {
for (dst_port, wire) in inputs.into_iter().enumerate() {
wire_up(
data_builder,
wire.node(),
wire.source().index(),
op_node,
dst_port,
)?;
wire_up(data_builder, wire.node(), wire.source(), op_node, dst_port)?;
}
Ok(())
}
Expand All @@ -664,10 +658,12 @@ fn wire_up_inputs<T: Dataflow + ?Sized>(
fn wire_up<T: Dataflow + ?Sized>(
data_builder: &mut T,
src: Node,
src_port: usize,
src_port: impl PortIndex,
dst: Node,
dst_port: usize,
dst_port: impl PortIndex,
) -> Result<bool, BuildError> {
let src_port = src_port.index();
let dst_port = dst_port.index();
let base = data_builder.hugr_mut();
let src_offset = Port::new_outgoing(src_port);

Expand Down
21 changes: 19 additions & 2 deletions src/hugr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,15 @@ pub struct Port {
offset: portgraph::PortOffset,
}

/// A trait for getting the undirected index of a port.
///
/// This allows functions to admit both [`Port`]s and explicit `usize`s for
/// identifying port offsets.
pub trait PortIndex {
/// Returns the offset of the port.
fn index(self) -> usize;
}

/// The direction of a port.
pub type Direction = portgraph::Direction;

Expand Down Expand Up @@ -382,14 +391,22 @@ impl Port {
pub fn direction(self) -> Direction {
self.offset.direction()
}
}

/// Returns the offset of the port.
impl PortIndex for Port {
#[inline(always)]
pub fn index(self) -> usize {
fn index(self) -> usize {
self.offset.index()
}
}

impl PortIndex for usize {
#[inline(always)]
fn index(self) -> usize {
self
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
/// A DataFlow wire, defined by a Value-kind output port of a node
// Stores node and offset to output port
Expand Down
14 changes: 7 additions & 7 deletions src/hugr/hugrmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{Hugr, Port};

use self::sealed::HugrMutInternals;

use super::{NodeMetadata, Rewrite};
use super::{NodeMetadata, PortIndex, Rewrite};

/// Functions for low-level building of a HUGR.
pub trait HugrMut: HugrView + HugrMutInternals {
Expand Down Expand Up @@ -108,9 +108,9 @@ pub trait HugrMut: HugrView + HugrMutInternals {
fn connect(
&mut self,
src: Node,
src_port: usize,
src_port: impl PortIndex,
dst: Node,
dst_port: usize,
dst_port: impl PortIndex,
) -> Result<(), HugrError> {
self.valid_node(src)?;
self.valid_node(dst)?;
Expand Down Expand Up @@ -234,13 +234,13 @@ where
fn connect(
&mut self,
src: Node,
src_port: usize,
src_port: impl PortIndex,
dst: Node,
dst_port: usize,
dst_port: impl PortIndex,
) -> Result<(), HugrError> {
self.as_mut()
.graph
.link_nodes(src.index, src_port, dst.index, dst_port)?;
.link_nodes(src.index, src_port.index(), dst.index, dst_port.index())?;
Ok(())
}

Expand All @@ -265,7 +265,7 @@ where
.get_optype(dst)
.other_port_index(Direction::Incoming)
.expect("Destination operation has no non-dataflow incoming edges");
self.connect(src, src_port.index(), dst, dst_port.index())?;
self.connect(src, src_port, dst, dst_port)?;
Ok((src_port, dst_port))
}

Expand Down
4 changes: 2 additions & 2 deletions src/hugr/rewrite/insert_identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ impl Rewrite for IdentityInsertion {
let new_node = h
.add_op_with_parent(parent, LeafOp::Noop { ty })
.expect("Parent validity already checked.");
h.connect(pre_node, pre_port.index(), new_node, 0)
h.connect(pre_node, pre_port, new_node, 0)
.expect("Should only fail if ports don't exist.");

h.connect(new_node, 0, self.post_node, self.post_port.index())
h.connect(new_node, 0, self.post_node, self.post_port)
.expect("Should only fail if ports don't exist.");
Ok(new_node)
}
Expand Down
8 changes: 3 additions & 5 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::extension::ExtensionSet;
use crate::hugr::hugrmut::sealed::HugrMutInternals;
use crate::hugr::rewrite::Rewrite;
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::{HugrMut, HugrView};
use crate::hugr::{HugrMut, HugrView, PortIndex};
use crate::ops;
use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
use crate::ops::{BasicBlock, OpTrait, OpType};
Expand Down Expand Up @@ -155,7 +155,7 @@ impl Rewrite for OutlineCfg {
for (pred, br) in preds {
if !self.blocks.contains(&pred) {
h.disconnect(pred, br).unwrap();
h.connect(pred, br.index(), new_block, 0).unwrap();
h.connect(pred, br, new_block, 0).unwrap();
}
}
if entry == outer_entry {
Expand Down Expand Up @@ -204,9 +204,7 @@ impl Rewrite for OutlineCfg {
SiblingMut::try_new(h, new_block).unwrap();
let mut in_cfg_view: SiblingMut<'_, CfgID> =
SiblingMut::try_new(&mut in_bb_view, cfg_node).unwrap();
in_cfg_view
.connect(exit, exit_port.index(), inner_exit, 0)
.unwrap();
in_cfg_view.connect(exit, exit_port, inner_exit, 0).unwrap();

Ok(())
}
Expand Down
14 changes: 7 additions & 7 deletions src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl Rewrite for SimpleReplacement {
for target in self.replacement.linked_ports(node, outport) {
if self.replacement.get_optype(target.0).tag() != OpTag::Output {
let new_target = index_map.get(&target.0).unwrap();
h.connect(*new_node, outport.index(), *new_target, target.1.index())
h.connect(*new_node, outport, *new_target, target.1)
.unwrap();
}
}
Expand All @@ -127,9 +127,9 @@ impl Rewrite for SimpleReplacement {
let new_inp_node = index_map.get(rep_inp_node).unwrap();
h.connect(
rem_inp_pred_node,
rem_inp_pred_port.index(),
rem_inp_pred_port,
*new_inp_node,
rep_inp_port.offset.index(),
*rep_inp_port,
)
.unwrap();
}
Expand All @@ -147,9 +147,9 @@ impl Rewrite for SimpleReplacement {
h.disconnect(*rem_out_node, *rem_out_port).unwrap();
h.connect(
*new_out_node,
rep_out_pred_port.index(),
rep_out_pred_port,
*rem_out_node,
rem_out_port.index(),
*rem_out_port,
)
.unwrap();
}
Expand All @@ -169,9 +169,9 @@ impl Rewrite for SimpleReplacement {
h.disconnect(*rem_out_node, *rem_out_port).unwrap();
h.connect(
rem_inp_pred_node,
rem_inp_pred_port.index(),
rem_inp_pred_port,
*rem_out_node,
rem_out_port.index(),
*rem_out_port,
)
.unwrap();
}
Expand Down
2 changes: 1 addition & 1 deletion src/hugr/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use portgraph::{Direction, LinkError, NodeIndex, PortView};

use serde::{Deserialize, Deserializer, Serialize};

use super::{HugrError, HugrMut, HugrView};
use super::{HugrError, HugrMut, HugrView, PortIndex};

/// A wrapper over the available HUGR serialization formats.
///
Expand Down
3 changes: 2 additions & 1 deletion src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod leaf;
pub mod module;
pub mod tag;
pub mod validate;
use crate::hugr::PortIndex;
use crate::types::{EdgeKind, FunctionType, SignatureDescription, Type};
use crate::{Direction, Port};

Expand Down Expand Up @@ -77,7 +78,7 @@ impl OpType {
/// Returns the edge kind for the given port.
pub fn port_kind(&self, port: impl Into<Port>) -> Option<EdgeKind> {
let signature = self.signature();
let port = port.into();
let port: Port = port.into();
let dir = port.direction();

let port_count = signature.port_count(dir);
Expand Down
10 changes: 5 additions & 5 deletions src/types/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use smol_str::SmolStr;

use std::fmt::{self, Display, Write};

use crate::hugr::Direction;
use crate::hugr::{Direction, PortIndex};

use super::{Type, TypeRow};

Expand Down Expand Up @@ -105,8 +105,8 @@ impl FunctionType {
#[inline]
pub fn get(&self, port: Port) -> Option<&Type> {
match port.direction() {
Direction::Incoming => self.input.get(port.index()),
Direction::Outgoing => self.output.get(port.index()),
Direction::Incoming => self.input.get(port),
Direction::Outgoing => self.output.get(port),
}
}

Expand All @@ -115,8 +115,8 @@ impl FunctionType {
#[inline]
pub fn get_mut(&mut self, port: Port) -> Option<&mut Type> {
match port.direction() {
Direction::Incoming => self.input.get_mut(port.index()),
Direction::Outgoing => self.output.get_mut(port.index()),
Direction::Incoming => self.input.get_mut(port),
Direction::Outgoing => self.output.get_mut(port),
}
}

Expand Down
22 changes: 13 additions & 9 deletions src/types/type_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{
};

use super::Type;
use crate::hugr::PortIndex;
use crate::utils::display_list;
use delegate::delegate;

Expand Down Expand Up @@ -42,6 +43,18 @@ impl TypeRow {
}
}

#[inline(always)]
/// Returns the port type given an offset. Returns `None` if the offset is out of bounds.
pub fn get(&self, offset: impl PortIndex) -> Option<&Type> {
self.types.get(offset.index())
}

#[inline(always)]
/// Returns the port type given an offset. Returns `None` if the offset is out of bounds.
pub fn get_mut(&mut self, offset: impl PortIndex) -> Option<&mut Type> {
self.types.to_mut().get_mut(offset.index())
}

delegate! {
to self.types {
/// Iterator over the types in the row.
Expand All @@ -56,19 +69,10 @@ impl TypeRow {
/// Allow access (consumption) of the contained elements
pub fn into_owned(self) -> Vec<Type>;

/// Returns the port type given an offset. Returns `None` if the offset is out of bounds.
pub fn get(&self, offset: usize) -> Option<&Type>;

/// Returns `true` if the row contains no types.
pub fn is_empty(&self) -> bool ;
}
}

#[inline(always)]
/// Returns the port type given an offset. Returns `None` if the offset is out of bounds.
pub fn get_mut(&mut self, offset: usize) -> Option<&mut Type> {
self.types.to_mut().get_mut(offset)
}
}

impl Default for TypeRow {
Expand Down

0 comments on commit ffa2f0c

Please sign in to comment.