Skip to content

Commit

Permalink
refactor: make clear const folding only for leaf ops (#785)
Browse files Browse the repository at this point in the history
  • Loading branch information
ss2165 committed Jan 8, 2024
1 parent b680662 commit ca07831
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/algorithm/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
views::SiblingSubgraph,
HugrMut,
},
ops::{Const, LeafOp, OpType},
ops::{Const, LeafOp},
type_row,
types::{FunctionType, Type, TypeEnum},
values::Value,
Expand Down Expand Up @@ -44,9 +44,7 @@ pub(crate) fn sorted_consts(consts: &[(IncomingPort, Const)]) -> Vec<&Const> {
.collect()
}
/// For a given op and consts, attempt to evaluate the op.
pub fn fold_const(op: &OpType, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
let op = op.as_leaf_op()?;

pub fn fold_leaf_op(op: &LeafOp, consts: &[(IncomingPort, Const)]) -> ConstFoldResult {
match op {
LeafOp::Noop { .. } => out_row([consts.first()?.1.clone()]),
LeafOp::MakeTuple { .. } => {
Expand Down Expand Up @@ -138,16 +136,17 @@ fn fold_op(
op_node: Node,
reg: &ExtensionRegistry,
) -> Option<(SimpleReplacement, Vec<RemoveConstIgnore>)> {
// only support leaf folding for now.
let neighbour_op = hugr.get_optype(op_node).as_leaf_op()?;
let (in_consts, removals): (Vec<_>, Vec<_>) = hugr
.node_inputs(op_node)
.filter_map(|in_p| {
let (con_op, load_n) = get_const(hugr, op_node, in_p)?;
Some(((in_p, con_op), RemoveConstIgnore(load_n)))
})
.unzip();
let neighbour_op = hugr.get_optype(op_node);
// attempt to evaluate op
let folded = fold_const(neighbour_op, &in_consts)?;
let folded = fold_leaf_op(neighbour_op, &in_consts)?;
let (op_outs, consts): (Vec<_>, Vec<_>) = folded.into_iter().unzip();
let nu_out = op_outs
.into_iter()
Expand Down Expand Up @@ -220,6 +219,7 @@ mod test {
use super::*;
use crate::extension::prelude::sum_with_error;
use crate::extension::{ExtensionRegistry, PRELUDE};
use crate::ops::OpType;
use crate::std_extensions::arithmetic;
use crate::std_extensions::arithmetic::conversions::ConvertOpDef;
use crate::std_extensions::arithmetic::float_ops::FloatOps;
Expand Down Expand Up @@ -249,7 +249,7 @@ mod test {
fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) {
let consts = vec![(0.into(), f2c(a)), (1.into(), f2c(b))];
let add_op: OpType = FloatOps::fadd.into();
let out = fold_const(&add_op, &consts).unwrap();
let out = fold_leaf_op(add_op.as_leaf_op().unwrap(), &consts).unwrap();

assert_eq!(&out[..], &[(0.into(), f2c(c))]);
}
Expand Down

0 comments on commit ca07831

Please sign in to comment.