From ffa2f0cb96548197a2e7eb51a961eaca2d7624f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Fri, 22 Sep 2023 13:04:26 +0200 Subject: [PATCH] feat: PortIndex trait for undirected port parameters (#553) Add a `PortIndex` trait implemented by both `Port` and `usize`, to use in all the methods that expect an undirected port offset (instead of asking for an `usize`). --- src/builder/build_traits.rs | 16 ++++++---------- src/hugr.rs | 21 +++++++++++++++++++-- src/hugr/hugrmut.rs | 14 +++++++------- src/hugr/rewrite/insert_identity.rs | 4 ++-- src/hugr/rewrite/outline_cfg.rs | 8 +++----- src/hugr/rewrite/simple_replace.rs | 14 +++++++------- src/hugr/serialize.rs | 2 +- src/ops.rs | 3 ++- src/types/signature.rs | 10 +++++----- src/types/type_row.rs | 22 +++++++++++++--------- 10 files changed, 65 insertions(+), 49 deletions(-) diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 1f90d142d..7d9e2926c 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -1,7 +1,7 @@ use crate::hugr::hugrmut::InsertionResult; use crate::hugr::validate::InterGraphEdgeError; use crate::hugr::views::HugrView; -use crate::hugr::{Node, NodeMetadata, Port, ValidationError}; +use crate::hugr::{Node, NodeMetadata, Port, PortIndex, ValidationError}; use crate::ops::{self, LeafOp, OpTrait, OpType}; use std::iter; @@ -649,13 +649,7 @@ fn wire_up_inputs( data_builder: &mut T, ) -> Result<(), BuildError> { for (dst_port, wire) in inputs.into_iter().enumerate() { - wire_up( - data_builder, - wire.node(), - wire.source().index(), - op_node, - dst_port, - )?; + wire_up(data_builder, wire.node(), wire.source(), op_node, dst_port)?; } Ok(()) } @@ -664,10 +658,12 @@ fn wire_up_inputs( fn wire_up( data_builder: &mut T, src: Node, - src_port: usize, + src_port: impl PortIndex, dst: Node, - dst_port: usize, + dst_port: impl PortIndex, ) -> Result { + let src_port = src_port.index(); + let dst_port = dst_port.index(); let base = data_builder.hugr_mut(); let src_offset = Port::new_outgoing(src_port); diff --git a/src/hugr.rs b/src/hugr.rs index d912b3c66..6f78fe43f 100644 --- a/src/hugr.rs +++ b/src/hugr.rs @@ -207,6 +207,15 @@ pub struct Port { offset: portgraph::PortOffset, } +/// A trait for getting the undirected index of a port. +/// +/// This allows functions to admit both [`Port`]s and explicit `usize`s for +/// identifying port offsets. +pub trait PortIndex { + /// Returns the offset of the port. + fn index(self) -> usize; +} + /// The direction of a port. pub type Direction = portgraph::Direction; @@ -382,14 +391,22 @@ impl Port { pub fn direction(self) -> Direction { self.offset.direction() } +} - /// Returns the offset of the port. +impl PortIndex for Port { #[inline(always)] - pub fn index(self) -> usize { + fn index(self) -> usize { self.offset.index() } } +impl PortIndex for usize { + #[inline(always)] + fn index(self) -> usize { + self + } +} + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] /// A DataFlow wire, defined by a Value-kind output port of a node // Stores node and offset to output port diff --git a/src/hugr/hugrmut.rs b/src/hugr/hugrmut.rs index 2ae0f0078..3eb23e63b 100644 --- a/src/hugr/hugrmut.rs +++ b/src/hugr/hugrmut.rs @@ -12,7 +12,7 @@ use crate::{Hugr, Port}; use self::sealed::HugrMutInternals; -use super::{NodeMetadata, Rewrite}; +use super::{NodeMetadata, PortIndex, Rewrite}; /// Functions for low-level building of a HUGR. pub trait HugrMut: HugrView + HugrMutInternals { @@ -108,9 +108,9 @@ pub trait HugrMut: HugrView + HugrMutInternals { fn connect( &mut self, src: Node, - src_port: usize, + src_port: impl PortIndex, dst: Node, - dst_port: usize, + dst_port: impl PortIndex, ) -> Result<(), HugrError> { self.valid_node(src)?; self.valid_node(dst)?; @@ -234,13 +234,13 @@ where fn connect( &mut self, src: Node, - src_port: usize, + src_port: impl PortIndex, dst: Node, - dst_port: usize, + dst_port: impl PortIndex, ) -> Result<(), HugrError> { self.as_mut() .graph - .link_nodes(src.index, src_port, dst.index, dst_port)?; + .link_nodes(src.index, src_port.index(), dst.index, dst_port.index())?; Ok(()) } @@ -265,7 +265,7 @@ where .get_optype(dst) .other_port_index(Direction::Incoming) .expect("Destination operation has no non-dataflow incoming edges"); - self.connect(src, src_port.index(), dst, dst_port.index())?; + self.connect(src, src_port, dst, dst_port)?; Ok((src_port, dst_port)) } diff --git a/src/hugr/rewrite/insert_identity.rs b/src/hugr/rewrite/insert_identity.rs index 7d04ccc3b..6c2e1072c 100644 --- a/src/hugr/rewrite/insert_identity.rs +++ b/src/hugr/rewrite/insert_identity.rs @@ -91,10 +91,10 @@ impl Rewrite for IdentityInsertion { let new_node = h .add_op_with_parent(parent, LeafOp::Noop { ty }) .expect("Parent validity already checked."); - h.connect(pre_node, pre_port.index(), new_node, 0) + h.connect(pre_node, pre_port, new_node, 0) .expect("Should only fail if ports don't exist."); - h.connect(new_node, 0, self.post_node, self.post_port.index()) + h.connect(new_node, 0, self.post_node, self.post_port) .expect("Should only fail if ports don't exist."); Ok(new_node) } diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index bfa8fa753..bd28e0246 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -9,7 +9,7 @@ use crate::extension::ExtensionSet; use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::rewrite::Rewrite; use crate::hugr::views::sibling::SiblingMut; -use crate::hugr::{HugrMut, HugrView}; +use crate::hugr::{HugrMut, HugrView, PortIndex}; use crate::ops; use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; use crate::ops::{BasicBlock, OpTrait, OpType}; @@ -155,7 +155,7 @@ impl Rewrite for OutlineCfg { for (pred, br) in preds { if !self.blocks.contains(&pred) { h.disconnect(pred, br).unwrap(); - h.connect(pred, br.index(), new_block, 0).unwrap(); + h.connect(pred, br, new_block, 0).unwrap(); } } if entry == outer_entry { @@ -204,9 +204,7 @@ impl Rewrite for OutlineCfg { SiblingMut::try_new(h, new_block).unwrap(); let mut in_cfg_view: SiblingMut<'_, CfgID> = SiblingMut::try_new(&mut in_bb_view, cfg_node).unwrap(); - in_cfg_view - .connect(exit, exit_port.index(), inner_exit, 0) - .unwrap(); + in_cfg_view.connect(exit, exit_port, inner_exit, 0).unwrap(); Ok(()) } diff --git a/src/hugr/rewrite/simple_replace.rs b/src/hugr/rewrite/simple_replace.rs index 6d127def5..c40f9776c 100644 --- a/src/hugr/rewrite/simple_replace.rs +++ b/src/hugr/rewrite/simple_replace.rs @@ -107,7 +107,7 @@ impl Rewrite for SimpleReplacement { for target in self.replacement.linked_ports(node, outport) { if self.replacement.get_optype(target.0).tag() != OpTag::Output { let new_target = index_map.get(&target.0).unwrap(); - h.connect(*new_node, outport.index(), *new_target, target.1.index()) + h.connect(*new_node, outport, *new_target, target.1) .unwrap(); } } @@ -127,9 +127,9 @@ impl Rewrite for SimpleReplacement { let new_inp_node = index_map.get(rep_inp_node).unwrap(); h.connect( rem_inp_pred_node, - rem_inp_pred_port.index(), + rem_inp_pred_port, *new_inp_node, - rep_inp_port.offset.index(), + *rep_inp_port, ) .unwrap(); } @@ -147,9 +147,9 @@ impl Rewrite for SimpleReplacement { h.disconnect(*rem_out_node, *rem_out_port).unwrap(); h.connect( *new_out_node, - rep_out_pred_port.index(), + rep_out_pred_port, *rem_out_node, - rem_out_port.index(), + *rem_out_port, ) .unwrap(); } @@ -169,9 +169,9 @@ impl Rewrite for SimpleReplacement { h.disconnect(*rem_out_node, *rem_out_port).unwrap(); h.connect( rem_inp_pred_node, - rem_inp_pred_port.index(), + rem_inp_pred_port, *rem_out_node, - rem_out_port.index(), + *rem_out_port, ) .unwrap(); } diff --git a/src/hugr/serialize.rs b/src/hugr/serialize.rs index 53bff8cc6..89b2da36f 100644 --- a/src/hugr/serialize.rs +++ b/src/hugr/serialize.rs @@ -18,7 +18,7 @@ use portgraph::{Direction, LinkError, NodeIndex, PortView}; use serde::{Deserialize, Deserializer, Serialize}; -use super::{HugrError, HugrMut, HugrView}; +use super::{HugrError, HugrMut, HugrView, PortIndex}; /// A wrapper over the available HUGR serialization formats. /// diff --git a/src/ops.rs b/src/ops.rs index 0334becd3..ad99e26b3 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -9,6 +9,7 @@ pub mod leaf; pub mod module; pub mod tag; pub mod validate; +use crate::hugr::PortIndex; use crate::types::{EdgeKind, FunctionType, SignatureDescription, Type}; use crate::{Direction, Port}; @@ -77,7 +78,7 @@ impl OpType { /// Returns the edge kind for the given port. pub fn port_kind(&self, port: impl Into) -> Option { let signature = self.signature(); - let port = port.into(); + let port: Port = port.into(); let dir = port.direction(); let port_count = signature.port_count(dir); diff --git a/src/types/signature.rs b/src/types/signature.rs index 076ba8815..38a816092 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -9,7 +9,7 @@ use smol_str::SmolStr; use std::fmt::{self, Display, Write}; -use crate::hugr::Direction; +use crate::hugr::{Direction, PortIndex}; use super::{Type, TypeRow}; @@ -105,8 +105,8 @@ impl FunctionType { #[inline] pub fn get(&self, port: Port) -> Option<&Type> { match port.direction() { - Direction::Incoming => self.input.get(port.index()), - Direction::Outgoing => self.output.get(port.index()), + Direction::Incoming => self.input.get(port), + Direction::Outgoing => self.output.get(port), } } @@ -115,8 +115,8 @@ impl FunctionType { #[inline] pub fn get_mut(&mut self, port: Port) -> Option<&mut Type> { match port.direction() { - Direction::Incoming => self.input.get_mut(port.index()), - Direction::Outgoing => self.output.get_mut(port.index()), + Direction::Incoming => self.input.get_mut(port), + Direction::Outgoing => self.output.get_mut(port), } } diff --git a/src/types/type_row.rs b/src/types/type_row.rs index b3b69bafe..8a2bedcf6 100644 --- a/src/types/type_row.rs +++ b/src/types/type_row.rs @@ -8,6 +8,7 @@ use std::{ }; use super::Type; +use crate::hugr::PortIndex; use crate::utils::display_list; use delegate::delegate; @@ -42,6 +43,18 @@ impl TypeRow { } } + #[inline(always)] + /// Returns the port type given an offset. Returns `None` if the offset is out of bounds. + pub fn get(&self, offset: impl PortIndex) -> Option<&Type> { + self.types.get(offset.index()) + } + + #[inline(always)] + /// Returns the port type given an offset. Returns `None` if the offset is out of bounds. + pub fn get_mut(&mut self, offset: impl PortIndex) -> Option<&mut Type> { + self.types.to_mut().get_mut(offset.index()) + } + delegate! { to self.types { /// Iterator over the types in the row. @@ -56,19 +69,10 @@ impl TypeRow { /// Allow access (consumption) of the contained elements pub fn into_owned(self) -> Vec; - /// Returns the port type given an offset. Returns `None` if the offset is out of bounds. - pub fn get(&self, offset: usize) -> Option<&Type>; - /// Returns `true` if the row contains no types. pub fn is_empty(&self) -> bool ; } } - - #[inline(always)] - /// Returns the port type given an offset. Returns `None` if the offset is out of bounds. - pub fn get_mut(&mut self, offset: usize) -> Option<&mut Type> { - self.types.to_mut().get_mut(offset) - } } impl Default for TypeRow {