From 7b6c16a262381c98d64b973f03c84144fb68f588 Mon Sep 17 00:00:00 2001 From: Matt Keeter Date: Fri, 30 Aug 2024 08:31:21 -0400 Subject: [PATCH] API plumbing for multi-output tapes (#163) --- CHANGELOG.md | 7 +- demos/constraints/src/main.rs | 7 +- fidget/src/core/compiler/ssa_tape.rs | 37 ++- fidget/src/core/context/mod.rs | 4 +- fidget/src/core/eval/bulk.rs | 2 +- fidget/src/core/eval/mod.rs | 9 +- fidget/src/core/eval/test/float_slice.rs | 20 +- fidget/src/core/eval/test/grad_slice.rs | 44 +-- fidget/src/core/eval/test/interval.rs | 310 ++++++++++---------- fidget/src/core/eval/test/point.rs | 175 ++++++----- fidget/src/core/eval/test/symbolic_deriv.rs | 10 +- fidget/src/core/eval/tracing.rs | 2 +- fidget/src/core/shape/mod.rs | 17 +- fidget/src/core/vm/data.rs | 17 +- fidget/src/core/vm/mod.rs | 31 +- fidget/src/jit/mod.rs | 50 +++- fidget/src/solver/mod.rs | 20 +- 17 files changed, 429 insertions(+), 333 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 56a44aa9..31851b1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # 0.3.3 -(nothing here yet) +- `Function` and evaluator types now produce multiple outputs + - `MathFunction::new` now takes a slice of nodes, instead of a single node + - All of the intermediate tape formats (`SsaTape`, etc) are aware of + multiple output nodes + - Evaluation now returns a slice of outputs, one for each root node (ordered + based on order in the `&[Node]` slice passed to `MathFunction::new`) # 0.3.2 - Added `impl IntoNode for Var`, to make handling `Var` values in a context diff --git a/demos/constraints/src/main.rs b/demos/constraints/src/main.rs index 620bccda..8d21a36b 100644 --- a/demos/constraints/src/main.rs +++ b/demos/constraints/src/main.rs @@ -87,7 +87,7 @@ impl ConstraintsApp { .into_iter() .map(|eqn| { let root = ctx.import(&eqn); - fidget::vm::VmFunction::new(&ctx, root).unwrap() + fidget::vm::VmFunction::new(&ctx, &[root]).unwrap() }) .collect::>(); @@ -180,8 +180,9 @@ impl eframe::App for ConstraintsApp { if *dragged { let v = ctx.var(var); let weight = ctx.sub(v, p).unwrap(); - let f = fidget::vm::VmFunction::new(&ctx, weight) - .unwrap(); + let f = + fidget::vm::VmFunction::new(&ctx, &[weight]) + .unwrap(); constraints.push(f); } } diff --git a/fidget/src/core/compiler/ssa_tape.rs b/fidget/src/core/compiler/ssa_tape.rs index 6136bc7f..c5289330 100644 --- a/fidget/src/core/compiler/ssa_tape.rs +++ b/fidget/src/core/compiler/ssa_tape.rs @@ -26,6 +26,9 @@ pub struct SsaTape { /// Number of choice operations in the tape pub choice_count: usize, + + /// Number of output operations in the tape + pub output_count: usize, } impl SsaTape { @@ -33,7 +36,7 @@ impl SsaTape { /// /// This should always succeed unless the `root` is from a different /// `Context`, in which case `Error::BadNode` will be returned. - pub fn new(ctx: &Context, root: Node) -> Result<(Self, VarMap), Error> { + pub fn new(ctx: &Context, roots: &[Node]) -> Result<(Self, VarMap), Error> { let mut mapping = HashMap::new(); let mut parent_count: HashMap = HashMap::new(); let mut slot_count = 0; @@ -48,7 +51,7 @@ impl SsaTape { // Accumulate parent counts and declare all nodes let mut seen = HashSet::new(); let mut vars = VarMap::new(); - let mut todo = vec![root]; + let mut todo = roots.to_vec(); while let Some(node) = todo.pop() { if !seen.insert(node) { continue; @@ -76,15 +79,18 @@ impl SsaTape { // Now that we've populated our parents, flatten the graph let mut seen = HashSet::new(); - let mut todo = vec![root]; + let mut todo = roots.to_vec(); let mut choice_count = 0; let mut tape = vec![]; - match mapping[&root] { - Slot::Reg(out_reg) => tape.push(SsaOp::Output(out_reg, 0)), - Slot::Immediate(imm) => { - tape.push(SsaOp::Output(0, 0)); - tape.push(SsaOp::CopyImm(0, imm)); + for (i, r) in roots.iter().enumerate() { + let i = i as u32; + match mapping[r] { + Slot::Reg(out_reg) => tape.push(SsaOp::Output(out_reg, i)), + Slot::Immediate(imm) => { + tape.push(SsaOp::Output(0, i)); + tape.push(SsaOp::CopyImm(0, imm)); + } } } @@ -235,7 +241,14 @@ impl SsaTape { tape.push(op); } - Ok((SsaTape { tape, choice_count }, vars)) + Ok(( + SsaTape { + tape, + choice_count, + output_count: roots.len(), + }, + vars, + )) } /// Checks whether the tape is empty @@ -410,7 +423,7 @@ mod test { let c8 = ctx.sub(c7, r).unwrap(); let c9 = ctx.max(c8, c6).unwrap(); - let (tape, vs) = SsaTape::new(&ctx, c9).unwrap(); + let (tape, vs) = SsaTape::new(&ctx, &[c9]).unwrap(); assert_eq!(tape.len(), 9); assert_eq!(vs.len(), 2); } @@ -421,7 +434,7 @@ mod test { let x = ctx.x(); let x_squared = ctx.mul(x, x).unwrap(); - let (tape, vs) = SsaTape::new(&ctx, x_squared).unwrap(); + let (tape, vs) = SsaTape::new(&ctx, &[x_squared]).unwrap(); assert_eq!(tape.len(), 3); // x, square, output assert_eq!(vs.len(), 1); } @@ -430,7 +443,7 @@ mod test { fn test_constant() { let mut ctx = Context::new(); let p = ctx.constant(1.5); - let (tape, vs) = SsaTape::new(&ctx, p).unwrap(); + let (tape, vs) = SsaTape::new(&ctx, &[p]).unwrap(); assert_eq!(tape.len(), 2); // CopyImm, output assert_eq!(vs.len(), 0); } diff --git a/fidget/src/core/context/mod.rs b/fidget/src/core/context/mod.rs index 02a37075..2a9e375e 100644 --- a/fidget/src/core/context/mod.rs +++ b/fidget/src/core/context/mod.rs @@ -1566,7 +1566,7 @@ mod test { let c8 = ctx.sub(c7, r).unwrap(); let c9 = ctx.max(c8, c6).unwrap(); - let tape = VmData::<255>::new(&ctx, c9).unwrap(); + let tape = VmData::<255>::new(&ctx, &[c9]).unwrap(); assert_eq!(tape.len(), 9); assert_eq!(tape.vars.len(), 2); } @@ -1577,7 +1577,7 @@ mod test { let x = ctx.x(); let x_squared = ctx.mul(x, x).unwrap(); - let tape = VmData::<255>::new(&ctx, x_squared).unwrap(); + let tape = VmData::<255>::new(&ctx, &[x_squared]).unwrap(); assert_eq!(tape.len(), 3); // x, square, output assert_eq!(tape.vars.len(), 1); } diff --git a/fidget/src/core/eval/bulk.rs b/fidget/src/core/eval/bulk.rs index d62a8900..4e2bb72c 100644 --- a/fidget/src/core/eval/bulk.rs +++ b/fidget/src/core/eval/bulk.rs @@ -55,7 +55,7 @@ pub trait BulkEvaluator: Default { /// Container for bulk output results /// /// This container represents an array-of-arrays. It is indexed first by -/// variable, then by index within the evaluation array. +/// output index, then by index within the evaluation array. pub struct BulkOutput<'a, T> { data: &'a Vec>, len: usize, diff --git a/fidget/src/core/eval/mod.rs b/fidget/src/core/eval/mod.rs index cd78599c..7d18ef75 100644 --- a/fidget/src/core/eval/mod.rs +++ b/fidget/src/core/eval/mod.rs @@ -32,6 +32,13 @@ pub trait Tape { /// Returns a mapping from [`Var`](crate::var::Var) to evaluation index fn vars(&self) -> &VarMap; + + /// Returns the number of outputs written by this tape + /// + /// The order of outputs is set by the caller at tape construction, so we + /// don't need a map to determine the index of a particular output (unlike + /// variables). + fn output_count(&self) -> usize; } /// Represents the trace captured by a tracing evaluation @@ -175,7 +182,7 @@ pub trait Function: Send + Sync + Clone { /// A [`Function`] which can be built from a math expression pub trait MathFunction: Function { /// Builds a new function from the given context and node - fn new(ctx: &Context, node: Node) -> Result + fn new(ctx: &Context, nodes: &[Node]) -> Result where Self: Sized; } diff --git a/fidget/src/core/eval/test/float_slice.rs b/fidget/src/core/eval/test/float_slice.rs index b008817e..b1f7046d 100644 --- a/fidget/src/core/eval/test/float_slice.rs +++ b/fidget/src/core/eval/test/float_slice.rs @@ -24,8 +24,8 @@ impl TestFloatSlice { let x = ctx.x(); let x1 = ctx.add(x, 1.0).unwrap(); - let shape_x = F::new(&ctx, x).unwrap(); - let shape_x1 = F::new(&ctx, x1).unwrap(); + let shape_x = F::new(&ctx, &[x]).unwrap(); + let shape_x1 = F::new(&ctx, &[x1]).unwrap(); // This is a fuzz test for icache issues let mut eval = F::new_float_slice_eval(); @@ -53,7 +53,7 @@ impl TestFloatSlice { let y = ctx.y(); let mut eval = F::new_float_slice_eval(); - let shape = F::new(&ctx, x).unwrap(); + let shape = F::new(&ctx, &[x]).unwrap(); let tape = shape.float_slice_tape(Default::default()); let out = eval .eval(&tape, &[[0.0, 1.0, 2.0, 3.0].as_slice()]) @@ -77,7 +77,7 @@ impl TestFloatSlice { assert_eq!(&out[0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); let mul = ctx.mul(y, 2.0).unwrap(); - let shape = F::new(&ctx, mul).unwrap(); + let shape = F::new(&ctx, &[mul]).unwrap(); let tape = shape.float_slice_tape(Default::default()); let out = eval .eval(&tape, &[[3.0, 2.0, 1.0, 0.0].as_slice()]) @@ -98,7 +98,7 @@ impl TestFloatSlice { let a = ctx.x(); let b = ctx.sin(a).unwrap(); - let shape = F::new(&ctx, b).unwrap(); + let shape = F::new(&ctx, &[b]).unwrap(); let mut eval = F::new_float_slice_eval(); let tape = shape.float_slice_tape(Default::default()); @@ -164,7 +164,7 @@ impl TestFloatSlice { let z: Vec = args[2..].iter().chain(&args[0..2]).cloned().collect(); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_float_slice_eval(); let tape = shape.float_slice_tape(Default::default()); @@ -222,7 +222,7 @@ impl TestFloatSlice { let node = C::build(&mut ctx, v); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_float_slice_eval(); let tape = shape.float_slice_tape(Default::default()); @@ -271,7 +271,7 @@ impl TestFloatSlice { rgsa.rotate_left(rot); let node = C::build(&mut ctx, va, vb); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_float_slice_eval(); let tape = shape.float_slice_tape(Default::default()); let vars = tape.vars(); @@ -309,7 +309,7 @@ impl TestFloatSlice { for rhs in args.iter() { let node = C::build(&mut ctx, va, *rhs); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_float_slice_eval(); let tape = shape.float_slice_tape(Default::default()); @@ -340,7 +340,7 @@ impl TestFloatSlice { for lhs in args.iter() { let node = C::build(&mut ctx, *lhs, va); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_float_slice_eval(); let tape = shape.float_slice_tape(Default::default()); diff --git a/fidget/src/core/eval/test/grad_slice.rs b/fidget/src/core/eval/test/grad_slice.rs index 37a756c7..11d26acb 100644 --- a/fidget/src/core/eval/test/grad_slice.rs +++ b/fidget/src/core/eval/test/grad_slice.rs @@ -56,7 +56,7 @@ impl TestGradSlice { pub fn test_g_x() { let mut ctx = Context::new(); let x = ctx.x(); - let shape = F::new(&ctx, x).unwrap(); + let shape = F::new(&ctx, &[x]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -68,7 +68,7 @@ impl TestGradSlice { pub fn test_g_y() { let mut ctx = Context::new(); let y = ctx.y(); - let shape = F::new(&ctx, y).unwrap(); + let shape = F::new(&ctx, &[y]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -80,7 +80,7 @@ impl TestGradSlice { pub fn test_g_z() { let mut ctx = Context::new(); let z = ctx.z(); - let shape = F::new(&ctx, z).unwrap(); + let shape = F::new(&ctx, &[z]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -93,7 +93,7 @@ impl TestGradSlice { let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.square(x).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -118,7 +118,7 @@ impl TestGradSlice { let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.abs(x).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -135,7 +135,7 @@ impl TestGradSlice { let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sqrt(x).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -152,7 +152,7 @@ impl TestGradSlice { let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let v = Self::eval_xyz(&tape, &[1.0, 2.0, 3.0], &[0.0; 3], &[0.0; 3]); @@ -163,7 +163,7 @@ impl TestGradSlice { let y = ctx.y(); let y = ctx.mul(y, 2.0).unwrap(); let s = ctx.sin(y).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let v = Self::eval_xyz(&tape, &[0.0; 3], &[1.0, 2.0, 3.0], &[0.0; 3]); v[0].compare_eq(Grad::new(2f32.sin(), 0.0, 2.0 * 2f32.cos(), 0.0)); @@ -176,7 +176,7 @@ impl TestGradSlice { let x = ctx.x(); let y = ctx.y(); let s = ctx.mul(x, y).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -201,7 +201,7 @@ impl TestGradSlice { let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.div(x, 2.0).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -214,7 +214,7 @@ impl TestGradSlice { let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.recip(x).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -232,7 +232,7 @@ impl TestGradSlice { let x = ctx.x(); let y = ctx.y(); let m = ctx.min(x, y).unwrap(); - let shape = F::new(&ctx, m).unwrap(); + let shape = F::new(&ctx, &[m]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -252,7 +252,7 @@ impl TestGradSlice { let z = ctx.z(); let min = ctx.min(x, y).unwrap(); let max = ctx.max(min, z).unwrap(); - let shape = F::new(&ctx, max).unwrap(); + let shape = F::new(&ctx, &[max]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -274,7 +274,7 @@ impl TestGradSlice { let x = ctx.x(); let y = ctx.y(); let m = ctx.max(x, y).unwrap(); - let shape = F::new(&ctx, m).unwrap(); + let shape = F::new(&ctx, &[m]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -291,7 +291,7 @@ impl TestGradSlice { let mut ctx = Context::new(); let x = ctx.x(); let m = ctx.not(x).unwrap(); - let shape = F::new(&ctx, m).unwrap(); + let shape = F::new(&ctx, &[m]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -310,7 +310,7 @@ impl TestGradSlice { let sum = ctx.add(x2, y2).unwrap(); let sqrt = ctx.sqrt(sum).unwrap(); let sub = ctx.sub(sqrt, 0.5).unwrap(); - let shape = F::new(&ctx, sub).unwrap(); + let shape = F::new(&ctx, &[sub]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); assert_eq!( @@ -342,7 +342,7 @@ impl TestGradSlice { let z: Vec = args[2..].iter().chain(&args[0..2]).cloned().collect(); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let out = Self::eval_xyz(&tape, &x, &y, &z); @@ -366,7 +366,7 @@ impl TestGradSlice { // Compare against the VmShape evaluator as a baseline. It's possible // that S is also a VmShape, but this comparison isn't particularly // expensive, so we'll do it regardless. - let shape = VmFunction::new(&ctx, node).unwrap(); + let shape = VmFunction::new(&ctx, &[node]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let cmp = TestGradSlice::::eval_xyz(&tape, &x, &y, &z); @@ -387,7 +387,7 @@ impl TestGradSlice { let mut ctx = Context::new(); let v = ctx.var(Var::new()); let node = C::build(&mut ctx, v); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let mut eval = F::new_grad_slice_eval(); @@ -546,7 +546,7 @@ impl TestGradSlice { for (j, &u) in inputs.iter().enumerate() { let node = C::build(&mut ctx, v, u); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let out = match (i, j) { @@ -592,7 +592,7 @@ impl TestGradSlice { for rhs in args.iter() { let node = C::build(&mut ctx, v, *rhs); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let out = match i { @@ -632,7 +632,7 @@ impl TestGradSlice { for lhs in args.iter() { let node = C::build(&mut ctx, *lhs, v); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let out = match i { diff --git a/fidget/src/core/eval/test/interval.rs b/fidget/src/core/eval/test/interval.rs index 6cb6db29..7dca29d7 100644 --- a/fidget/src/core/eval/test/interval.rs +++ b/fidget/src/core/eval/test/interval.rs @@ -29,27 +29,27 @@ where let x = ctx.x(); let y = ctx.y(); - let shape = F::new(&ctx, x).unwrap(); + let shape = F::new(&ctx, &[x]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [0.0, 1.0].into() ); assert_eq!( - eval.eval(&tape, &[[1.0, 5.0].into()]).unwrap().0, + eval.eval(&tape, &[[1.0, 5.0].into()]).unwrap().0[0], [1.0, 5.0].into() ); - let shape = F::new(&ctx, y).unwrap(); + let shape = F::new(&ctx, &[y]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[2.0, 3.0].into()]).unwrap().0, + eval.eval(&tape, &[[2.0, 3.0].into()]).unwrap().0[0], [2.0, 3.0].into() ); assert_eq!( - eval.eval(&tape, &[[4.0, 5.0].into()]).unwrap().0, + eval.eval(&tape, &[[4.0, 5.0].into()]).unwrap().0[0], [4.0, 5.0].into() ); } @@ -59,47 +59,47 @@ where let x = ctx.x(); let abs_x = ctx.abs(x).unwrap(); - let shape = F::new(&ctx, abs_x).unwrap(); + let shape = F::new(&ctx, &[abs_x]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [0.0, 1.0].into() ); assert_eq!( - eval.eval(&tape, &[[1.0, 5.0].into()]).unwrap().0, + eval.eval(&tape, &[[1.0, 5.0].into()]).unwrap().0[0], [1.0, 5.0].into() ); assert_eq!( - eval.eval(&tape, &[[-2.0, 5.0].into()]).unwrap().0, + eval.eval(&tape, &[[-2.0, 5.0].into()]).unwrap().0[0], [0.0, 5.0].into() ); assert_eq!( - eval.eval(&tape, &[[-6.0, 5.0].into()]).unwrap().0, + eval.eval(&tape, &[[-6.0, 5.0].into()]).unwrap().0[0], [0.0, 6.0].into() ); assert_eq!( - eval.eval(&tape, &[[-6.0, -1.0].into()]).unwrap().0, + eval.eval(&tape, &[[-6.0, -1.0].into()]).unwrap().0[0], [1.0, 6.0].into() ); let y = ctx.y(); let abs_y = ctx.abs(y).unwrap(); let sum = ctx.add(abs_x, abs_y).unwrap(); - let shape = F::new(&ctx, sum).unwrap(); + let shape = F::new(&ctx, &[sum]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &vs([0.0, 1.0], [0.0, 1.0])).unwrap().0, + eval.eval(&tape, &vs([0.0, 1.0], [0.0, 1.0])).unwrap().0[0], [0.0, 2.0].into() ); assert_eq!( - eval.eval(&tape, &vs([1.0, 5.0], [-2.0, 3.0])).unwrap().0, + eval.eval(&tape, &vs([1.0, 5.0], [-2.0, 3.0])).unwrap().0[0], [1.0, 8.0].into() ); assert_eq!( - eval.eval(&tape, &vs([1.0, 5.0], [-4.0, 3.0])).unwrap().0, + eval.eval(&tape, &vs([1.0, 5.0], [-4.0, 3.0])).unwrap().0[0], [1.0, 9.0].into() ); } @@ -110,12 +110,12 @@ where let v = ctx.add(x, 0.5).unwrap(); let out = ctx.abs(v).unwrap(); - let shape = F::new(&ctx, out).unwrap(); + let shape = F::new(&ctx, &[out]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[-1.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[-1.0, 1.0].into()]).unwrap().0[0], [0.0, 1.5].into() ); } @@ -125,29 +125,29 @@ where let x = ctx.x(); let sqrt_x = ctx.sqrt(x).unwrap(); - let shape = F::new(&ctx, sqrt_x).unwrap(); + let shape = F::new(&ctx, &[sqrt_x]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [0.0, 1.0].into() ); assert_eq!( - eval.eval(&tape, &[[0.0, 4.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 4.0].into()]).unwrap().0[0], [0.0, 2.0].into() ); // Even a partial negative returns a NAN interval - let nanan = eval.eval(&tape, &[[-2.0, 4.0].into()]).unwrap().0; + let nanan = eval.eval(&tape, &[[-2.0, 4.0].into()]).unwrap().0[0]; assert!(nanan.lower().is_nan()); assert!(nanan.upper().is_nan()); // Full negatives are right out - let nanan = eval.eval(&tape, &[[-2.0, -1.0].into()]).unwrap().0; + let nanan = eval.eval(&tape, &[[-2.0, -1.0].into()]).unwrap().0[0]; assert!(nanan.lower().is_nan()); assert!(nanan.upper().is_nan()); - let (v, _) = eval.eval(&tape, &[[f32::NAN; 2].into()]).unwrap(); + let v = eval.eval(&tape, &[[f32::NAN; 2].into()]).unwrap().0[0]; assert!(v.lower().is_nan()); assert!(v.upper().is_nan()); } @@ -157,35 +157,35 @@ where let x = ctx.x(); let sqrt_x = ctx.square(x).unwrap(); - let shape = F::new(&ctx, sqrt_x).unwrap(); + let shape = F::new(&ctx, &[sqrt_x]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [0.0, 1.0].into() ); assert_eq!( - eval.eval(&tape, &[[0.0, 4.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 4.0].into()]).unwrap().0[0], [0.0, 16.0].into() ); assert_eq!( - eval.eval(&tape, &[[2.0, 4.0].into()]).unwrap().0, + eval.eval(&tape, &[[2.0, 4.0].into()]).unwrap().0[0], [4.0, 16.0].into() ); assert_eq!( - eval.eval(&tape, &[[-2.0, 4.0].into()]).unwrap().0, + eval.eval(&tape, &[[-2.0, 4.0].into()]).unwrap().0[0], [0.0, 16.0].into() ); assert_eq!( - eval.eval(&tape, &[[-6.0, -2.0].into()]).unwrap().0, + eval.eval(&tape, &[[-6.0, -2.0].into()]).unwrap().0[0], [4.0, 36.0].into() ); assert_eq!( - eval.eval(&tape, &[[-6.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[-6.0, 1.0].into()]).unwrap().0[0], [0.0, 36.0].into() ); - let (v, _) = eval.eval(&tape, &[[f32::NAN; 2].into()]).unwrap(); + let v = eval.eval(&tape, &[[f32::NAN; 2].into()]).unwrap().0[0]; assert!(v.lower().is_nan()); assert!(v.upper().is_nan()); } @@ -194,12 +194,12 @@ where let mut ctx = Context::new(); let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [-1.0, 1.0].into() ); @@ -207,12 +207,12 @@ where let y = ctx.mul(y, 2.0).unwrap(); let s = ctx.sin(y).unwrap(); let s = ctx.add(x, s).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); assert_eq!( - eval.eval(&tape, &vs([0.0, 3.0], [0.0, 0.0])).unwrap().0, + eval.eval(&tape, &vs([0.0, 3.0], [0.0, 0.0])).unwrap().0[0], [-1.0, 4.0].into() ); } @@ -222,35 +222,35 @@ where let x = ctx.x(); let neg_x = ctx.neg(x).unwrap(); - let shape = F::new(&ctx, neg_x).unwrap(); + let shape = F::new(&ctx, &[neg_x]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [-1.0, 0.0].into() ); assert_eq!( - eval.eval(&tape, &[[0.0, 4.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 4.0].into()]).unwrap().0[0], [-4.0, 0.0].into() ); assert_eq!( - eval.eval(&tape, &[[2.0, 4.0].into()]).unwrap().0, + eval.eval(&tape, &[[2.0, 4.0].into()]).unwrap().0[0], [-4.0, -2.0].into() ); assert_eq!( - eval.eval(&tape, &[[-2.0, 4.0].into()]).unwrap().0, + eval.eval(&tape, &[[-2.0, 4.0].into()]).unwrap().0[0], [-4.0, 2.0].into() ); assert_eq!( - eval.eval(&tape, &[[-6.0, -2.0].into()]).unwrap().0, + eval.eval(&tape, &[[-6.0, -2.0].into()]).unwrap().0[0], [2.0, 6.0].into() ); assert_eq!( - eval.eval(&tape, &[[-6.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[-6.0, 1.0].into()]).unwrap().0[0], [-1.0, 6.0].into() ); - let (v, _) = eval.eval(&tape, &[[f32::NAN; 2].into()]).unwrap(); + let v = eval.eval(&tape, &[[f32::NAN; 2].into()]).unwrap().0[0]; assert!(v.lower().is_nan()); assert!(v.upper().is_nan()); } @@ -260,11 +260,11 @@ where let x = ctx.x(); let not_x = ctx.not(x).unwrap(); - let shape = F::new(&ctx, not_x).unwrap(); + let shape = F::new(&ctx, &[not_x]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[-5.0, 0.0].into()]).unwrap().0, + eval.eval(&tape, &[[-5.0, 0.0].into()]).unwrap().0[0], [0.0, 1.0].into() ); } @@ -275,36 +275,36 @@ where let y = ctx.y(); let mul = ctx.mul(x, y).unwrap(); - let shape = F::new(&ctx, mul).unwrap(); + let shape = F::new(&ctx, &[mul]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &vs([0.0, 1.0], [0.0, 1.0])).unwrap().0, + eval.eval(&tape, &vs([0.0, 1.0], [0.0, 1.0])).unwrap().0[0], [0.0, 1.0].into() ); assert_eq!( - eval.eval(&tape, &vs([0.0, 1.0], [0.0, 2.0])).unwrap().0, + eval.eval(&tape, &vs([0.0, 1.0], [0.0, 2.0])).unwrap().0[0], [0.0, 2.0].into() ); assert_eq!( - eval.eval(&tape, &vs([-2.0, 1.0], [0.0, 1.0])).unwrap().0, + eval.eval(&tape, &vs([-2.0, 1.0], [0.0, 1.0])).unwrap().0[0], [-2.0, 1.0].into() ); assert_eq!( - eval.eval(&tape, &vs([-2.0, -1.0], [-5.0, -4.0])).unwrap().0, + eval.eval(&tape, &vs([-2.0, -1.0], [-5.0, -4.0])).unwrap().0[0], [4.0, 10.0].into() ); assert_eq!( - eval.eval(&tape, &vs([-3.0, -1.0], [-2.0, 6.0])).unwrap().0, + eval.eval(&tape, &vs([-3.0, -1.0], [-2.0, 6.0])).unwrap().0[0], [-18.0, 6.0].into() ); - let (v, _) = eval.eval(&tape, &vs([f32::NAN; 2], [0.0, 1.0])).unwrap(); + let v = eval.eval(&tape, &vs([f32::NAN; 2], [0.0, 1.0])).unwrap().0[0]; assert!(v.lower().is_nan()); assert!(v.upper().is_nan()); - let (v, _) = eval.eval(&tape, &vs([0.0, 1.0], [f32::NAN; 2])).unwrap(); + let v = eval.eval(&tape, &vs([0.0, 1.0], [f32::NAN; 2])).unwrap().0[0]; assert!(v.lower().is_nan()); assert!(v.upper().is_nan()); } @@ -313,28 +313,28 @@ where let mut ctx = Context::new(); let x = ctx.x(); let mul = ctx.mul(x, 2.0).unwrap(); - let shape = F::new(&ctx, mul).unwrap(); + let shape = F::new(&ctx, &[mul]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [0.0, 2.0].into() ); assert_eq!( - eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0, + eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0[0], [2.0, 4.0].into() ); let mul = ctx.mul(x, -3.0).unwrap(); - let shape = F::new(&ctx, mul).unwrap(); + let shape = F::new(&ctx, &[mul]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [-3.0, 0.0].into() ); assert_eq!( - eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0, + eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0[0], [-6.0, -3.0].into() ); } @@ -345,28 +345,28 @@ where let y = ctx.y(); let sub = ctx.sub(x, y).unwrap(); - let shape = F::new(&ctx, sub).unwrap(); + let shape = F::new(&ctx, &[sub]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &vs([0.0, 1.0], [0.0, 1.0])).unwrap().0, + eval.eval(&tape, &vs([0.0, 1.0], [0.0, 1.0])).unwrap().0[0], [-1.0, 1.0].into() ); assert_eq!( - eval.eval(&tape, &vs([0.0, 1.0], [0.0, 2.0])).unwrap().0, + eval.eval(&tape, &vs([0.0, 1.0], [0.0, 2.0])).unwrap().0[0], [-2.0, 1.0].into() ); assert_eq!( - eval.eval(&tape, &vs([-2.0, 1.0], [0.0, 1.0])).unwrap().0, + eval.eval(&tape, &vs([-2.0, 1.0], [0.0, 1.0])).unwrap().0[0], [-3.0, 1.0].into() ); assert_eq!( - eval.eval(&tape, &vs([-2.0, -1.0], [-5.0, -4.0])).unwrap().0, + eval.eval(&tape, &vs([-2.0, -1.0], [-5.0, -4.0])).unwrap().0[0], [2.0, 4.0].into() ); assert_eq!( - eval.eval(&tape, &vs([-3.0, -1.0], [-2.0, 6.0])).unwrap().0, + eval.eval(&tape, &vs([-3.0, -1.0], [-2.0, 6.0])).unwrap().0[0], [-9.0, 1.0].into() ); } @@ -375,28 +375,28 @@ where let mut ctx = Context::new(); let x = ctx.x(); let sub = ctx.sub(x, 2.0).unwrap(); - let shape = F::new(&ctx, sub).unwrap(); + let shape = F::new(&ctx, &[sub]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [-2.0, -1.0].into() ); assert_eq!( - eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0, + eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0[0], [-1.0, 0.0].into() ); let sub = ctx.sub(-3.0, x).unwrap(); - let shape = F::new(&ctx, sub).unwrap(); + let shape = F::new(&ctx, &[sub]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); assert_eq!( - eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0, + eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0], [-4.0, -3.0].into() ); assert_eq!( - eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0, + eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0[0], [-5.0, -4.0].into() ); } @@ -405,28 +405,28 @@ where let mut ctx = Context::new(); let x = ctx.x(); let recip = ctx.recip(x).unwrap(); - let shape = F::new(&ctx, recip).unwrap(); + let shape = F::new(&ctx, &[recip]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); - let nanan = eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0; + let nanan = eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap().0[0]; assert!(nanan.lower().is_nan()); assert!(nanan.upper().is_nan()); - let nanan = eval.eval(&tape, &[[-1.0, 0.0].into()]).unwrap().0; + let nanan = eval.eval(&tape, &[[-1.0, 0.0].into()]).unwrap().0[0]; assert!(nanan.lower().is_nan()); assert!(nanan.upper().is_nan()); - let nanan = eval.eval(&tape, &[[-2.0, 3.0].into()]).unwrap().0; + let nanan = eval.eval(&tape, &[[-2.0, 3.0].into()]).unwrap().0[0]; assert!(nanan.lower().is_nan()); assert!(nanan.upper().is_nan()); assert_eq!( - eval.eval(&tape, &[[-2.0, -1.0].into()]).unwrap().0, + eval.eval(&tape, &[[-2.0, -1.0].into()]).unwrap().0[0], [-1.0, -0.5].into() ); assert_eq!( - eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0, + eval.eval(&tape, &[[1.0, 2.0].into()]).unwrap().0[0], [0.5, 1.0].into() ); } @@ -436,40 +436,43 @@ where let x = ctx.x(); let y = ctx.y(); let div = ctx.div(x, y).unwrap(); - let shape = F::new(&ctx, div).unwrap(); + let shape = F::new(&ctx, &[div]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_interval_eval(); - let nanan = eval.eval(&tape, &vs([0.0, 1.0], [-1.0, 1.0])).unwrap().0; + let nanan = + eval.eval(&tape, &vs([0.0, 1.0], [-1.0, 1.0])).unwrap().0[0]; assert!(nanan.lower().is_nan()); assert!(nanan.upper().is_nan()); - let nanan = eval.eval(&tape, &vs([0.0, 1.0], [-2.0, 0.0])).unwrap().0; + let nanan = + eval.eval(&tape, &vs([0.0, 1.0], [-2.0, 0.0])).unwrap().0[0]; assert!(nanan.lower().is_nan()); assert!(nanan.upper().is_nan()); - let nanan = eval.eval(&tape, &vs([0.0, 1.0], [0.0, 4.0])).unwrap().0; + let nanan = eval.eval(&tape, &vs([0.0, 1.0], [0.0, 4.0])).unwrap().0[0]; assert!(nanan.lower().is_nan()); assert!(nanan.upper().is_nan()); - let out = eval.eval(&tape, &vs([-1.0, 0.0], [1.0, 2.0])).unwrap().0; + let out = eval.eval(&tape, &vs([-1.0, 0.0], [1.0, 2.0])).unwrap().0[0]; assert_eq!(out, [-1.0, 0.0].into()); - let out = eval.eval(&tape, &vs([-1.0, 4.0], [-1.0, -0.5])).unwrap().0; + let out = + eval.eval(&tape, &vs([-1.0, 4.0], [-1.0, -0.5])).unwrap().0[0]; assert_eq!(out, [-8.0, 2.0].into()); - let out = eval.eval(&tape, &vs([1.0, 4.0], [-1.0, -0.5])).unwrap().0; + let out = eval.eval(&tape, &vs([1.0, 4.0], [-1.0, -0.5])).unwrap().0[0]; assert_eq!(out, [-8.0, -1.0].into()); - let out = eval.eval(&tape, &vs([-1.0, 4.0], [0.5, 1.0])).unwrap().0; + let out = eval.eval(&tape, &vs([-1.0, 4.0], [0.5, 1.0])).unwrap().0[0]; assert_eq!(out, [-2.0, 8.0].into()); - let (v, _) = eval.eval(&tape, &vs([f32::NAN; 2], [0.0, 1.0])).unwrap(); + let v = eval.eval(&tape, &vs([f32::NAN; 2], [0.0, 1.0])).unwrap().0[0]; assert!(v.lower().is_nan()); assert!(v.upper().is_nan()); - let (v, _) = eval.eval(&tape, &vs([0.0, 1.0], [f32::NAN; 2])).unwrap(); + let v = eval.eval(&tape, &vs([0.0, 1.0], [f32::NAN; 2])).unwrap().0[0]; assert!(v.lower().is_nan()); assert!(v.upper().is_nan()); } @@ -480,32 +483,32 @@ where let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = F::new(&ctx, min).unwrap(); + let shape = F::new(&ctx, &[min]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_interval_eval(); let (r, data) = eval.eval(&tape, &vs([0.0, 1.0], [0.5, 1.5])).unwrap(); - assert_eq!(r, [0.0, 1.0].into()); + assert_eq!(r[0], [0.0, 1.0].into()); assert!(data.is_none()); let (r, data) = eval.eval(&tape, &vs([0.0, 1.0], [2.0, 3.0])).unwrap(); - assert_eq!(r, [0.0, 1.0].into()); + assert_eq!(r[0], [0.0, 1.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Left]); let (r, data) = eval.eval(&tape, &vs([2.0, 3.0], [0.0, 1.0])).unwrap(); - assert_eq!(r, [0.0, 1.0].into()); + assert_eq!(r[0], [0.0, 1.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Right]); let (v, data) = eval.eval(&tape, &vs([f32::NAN; 2], [0.0, 1.0])).unwrap(); - assert!(v.lower().is_nan()); - assert!(v.upper().is_nan()); + assert!(v[0].lower().is_nan()); + assert!(v[0].upper().is_nan()); assert!(data.is_none()); let (v, data) = eval.eval(&tape, &vs([0.0, 1.0], [f32::NAN; 2])).unwrap(); - assert!(v.lower().is_nan()); - assert!(v.upper().is_nan()); + assert!(v[0].lower().is_nan()); + assert!(v[0].upper().is_nan()); assert!(data.is_none()); } @@ -514,19 +517,19 @@ where let x = ctx.x(); let min = ctx.min(x, 1.0).unwrap(); - let shape = F::new(&ctx, min).unwrap(); + let shape = F::new(&ctx, &[min]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); let (r, data) = eval.eval(&tape, &[[0.0, 1.0].into()]).unwrap(); - assert_eq!(r, [0.0, 1.0].into()); + assert_eq!(r[0], [0.0, 1.0].into()); assert!(data.is_none()); let (r, data) = eval.eval(&tape, &[[-1.0, 0.0].into()]).unwrap(); - assert_eq!(r, [-1.0, 0.0].into()); + assert_eq!(r[0], [-1.0, 0.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Left]); let (r, data) = eval.eval(&tape, &[[2.0, 3.0].into()]).unwrap(); - assert_eq!(r, [1.0, 1.0].into()); + assert_eq!(r[0], [1.0, 1.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Right]); } @@ -536,56 +539,56 @@ where let y = ctx.y(); let max = ctx.max(x, y).unwrap(); - let shape = F::new(&ctx, max).unwrap(); + let shape = F::new(&ctx, &[max]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_interval_eval(); let (r, data) = eval.eval(&tape, &vs([0.0, 1.0], [0.5, 1.5])).unwrap(); - assert_eq!(r, [0.5, 1.5].into()); + assert_eq!(r[0], [0.5, 1.5].into()); assert!(data.is_none()); let (r, data) = eval.eval(&tape, &vs([0.0, 1.0], [2.0, 3.0])).unwrap(); - assert_eq!(r, [2.0, 3.0].into()); + assert_eq!(r[0], [2.0, 3.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Right]); let (r, data) = eval.eval(&tape, &vs([2.0, 3.0], [0.0, 1.0])).unwrap(); - assert_eq!(r, [2.0, 3.0].into()); + assert_eq!(r[0], [2.0, 3.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Left]); let (v, data) = eval.eval(&tape, &vs([f32::NAN; 2], [0.0, 1.0])).unwrap(); - assert!(v.lower().is_nan()); - assert!(v.upper().is_nan()); + assert!(v[0].lower().is_nan()); + assert!(v[0].upper().is_nan()); assert!(data.is_none()); let (v, data) = eval.eval(&tape, &vs([0.0, 1.0], [f32::NAN; 2])).unwrap(); - assert!(v.lower().is_nan()); - assert!(v.upper().is_nan()); + assert!(v[0].lower().is_nan()); + assert!(v[0].upper().is_nan()); assert!(data.is_none()); let z = ctx.z(); let max_xy_z = ctx.max(max, z).unwrap(); - let shape = F::new(&ctx, max_xy_z).unwrap(); + let shape = F::new(&ctx, &[max_xy_z]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xyz(&tape); let mut eval = F::new_interval_eval(); let (r, data) = eval .eval(&tape, &vs([2.0, 3.0], [0.0, 1.0], [4.0, 5.0])) .unwrap(); - assert_eq!(r, [4.0, 5.0].into()); + assert_eq!(r[0], [4.0, 5.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Left, Choice::Right]); let (r, data) = eval .eval(&tape, &vs([2.0, 3.0], [0.0, 1.0], [1.0, 4.0])) .unwrap(); - assert_eq!(r, [2.0, 4.0].into()); + assert_eq!(r[0], [2.0, 4.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Left, Choice::Both]); let (r, data) = eval .eval(&tape, &vs([2.0, 3.0], [0.0, 1.0], [1.0, 1.5])) .unwrap(); - assert_eq!(r, [2.0, 3.0].into()); + assert_eq!(r[0], [2.0, 3.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Left, Choice::Left]); } @@ -595,28 +598,28 @@ where let y = ctx.y(); let v = ctx.and(x, y).unwrap(); - let shape = F::new(&ctx, v).unwrap(); + let shape = F::new(&ctx, &[v]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_interval_eval(); let (r, trace) = eval.eval(&tape, &vs([0.0, 0.0], [-1.0, 3.0])).unwrap(); - assert_eq!(r, [0.0, 0.0].into()); + assert_eq!(r[0], [0.0, 0.0].into()); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs([-1.0, -0.2], [-1.0, 3.0])).unwrap(); - assert_eq!(r, [-1.0, 3.0].into()); + assert_eq!(r[0], [-1.0, 3.0].into()); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs([0.2, 1.3], [-1.0, 3.0])).unwrap(); - assert_eq!(r, [-1.0, 3.0].into()); + assert_eq!(r[0], [-1.0, 3.0].into()); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs([-0.2, 1.3], [1.0, 3.0])).unwrap(); - assert_eq!(r, [0.0, 3.0].into()); + assert_eq!(r[0], [0.0, 3.0].into()); assert!(trace.is_none()); // can't simplify } @@ -626,28 +629,28 @@ where let y = ctx.y(); let v = ctx.or(x, y).unwrap(); - let shape = F::new(&ctx, v).unwrap(); + let shape = F::new(&ctx, &[v]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_interval_eval(); let (r, trace) = eval.eval(&tape, &vs([0.0, 0.0], [-1.0, 3.0])).unwrap(); - assert_eq!(r, [-1.0, 3.0].into()); + assert_eq!(r[0], [-1.0, 3.0].into()); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs([-1.0, -0.2], [-1.0, 3.0])).unwrap(); - assert_eq!(r, [-1.0, -0.2].into()); + assert_eq!(r[0], [-1.0, -0.2].into()); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs([0.2, 1.3], [-1.0, 3.0])).unwrap(); - assert_eq!(r, [0.2, 1.3].into()); + assert_eq!(r[0], [0.2, 1.3].into()); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs([-0.2, 1.3], [1.0, 3.0])).unwrap(); - assert_eq!(r, [-0.2, 3.0].into()); + assert_eq!(r[0], [-0.2, 3.0].into()); assert!(trace.is_none()); } @@ -656,35 +659,35 @@ where let x = ctx.x(); let min = ctx.min(x, 1.0).unwrap(); - let shape = F::new(&ctx, min).unwrap(); + let shape = F::new(&ctx, &[min]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); let (out, data) = eval.eval(&tape, &[[0.0, 2.0].into()]).unwrap(); - assert_eq!(out, [0.0, 1.0].into()); + assert_eq!(out[0], [0.0, 1.0].into()); assert!(data.is_none()); let (out, data) = eval.eval(&tape, &[[0.0, 0.5].into()]).unwrap(); - assert_eq!(out, [0.0, 0.5].into()); + assert_eq!(out[0], [0.0, 0.5].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Left]); let (out, data) = eval.eval(&tape, &[[1.5, 2.5].into()]).unwrap(); - assert_eq!(out, [1.0, 1.0].into()); + assert_eq!(out[0], [1.0, 1.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Right]); let max = ctx.max(x, 1.0).unwrap(); - let shape = F::new(&ctx, max).unwrap(); + let shape = F::new(&ctx, &[max]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); let (out, data) = eval.eval(&tape, &[[0.0, 2.0].into()]).unwrap(); - assert_eq!(out, [1.0, 2.0].into()); + assert_eq!(out[0], [1.0, 2.0].into()); assert!(data.is_none()); let (out, data) = eval.eval(&tape, &[[0.0, 0.5].into()]).unwrap(); - assert_eq!(out, [1.0, 1.0].into()); + assert_eq!(out[0], [1.0, 1.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Right]); let (out, data) = eval.eval(&tape, &[[1.5, 2.5].into()]).unwrap(); - assert_eq!(out, [1.5, 2.5].into()); + assert_eq!(out[0], [1.5, 2.5].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Left]); } @@ -695,7 +698,7 @@ where let z = ctx.z(); let if_else = ctx.if_nonzero_else(x, y, z).unwrap(); - let shape = F::new(&ctx, if_else).unwrap(); + let shape = F::new(&ctx, &[if_else]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xyz(&tape); @@ -706,14 +709,14 @@ where // Alas, we lose the information that the conditional is correlated, so // the interval naively must include 0 - assert_eq!(out, [0.0, 4.0].into()); + assert_eq!(out[0], [0.0, 4.0].into()); assert!(data.is_none()); // Confirm that simplification of the right side works let (out, data) = eval .eval(&tape, &vs([0.0, 0.0], [1.0, 2.0], [3.0, 4.0])) .unwrap(); - assert_eq!(out, [3.0, 4.0].into()); + assert_eq!(out[0], [3.0, 4.0].into()); let s_z = shape .simplify( data.expect("must have trace"), @@ -726,13 +729,13 @@ where .eval(&t_z, &vs([-1.0, 1.0], [1.0, 2.0], [5.0, 6.0])) .unwrap(); assert!(s_z.size() < shape.size()); - assert_eq!(out, [5.0, 6.0].into()); + assert_eq!(out[0], [5.0, 6.0].into()); assert!(data.is_none()); let (out, data) = eval .eval(&tape, &vs([1.0, 3.0], [1.0, 2.0], [3.0, 4.0])) .unwrap(); - assert_eq!(out, [1.0, 2.0].into()); + assert_eq!(out[0], [1.0, 2.0].into()); assert!(data.is_some()); let s_y = shape .simplify( @@ -746,7 +749,7 @@ where .eval(&t_y, &vs([-1.0, 1.0], [1.0, 4.0], [5.0, 6.0])) .unwrap(); assert!(s_y.size() < shape.size()); - assert_eq!(out, [1.0, 4.0].into()); + assert_eq!(out[0], [1.0, 4.0].into()); assert!(data.is_none()); assert_eq!(s_y.size(), s_z.size()) @@ -757,19 +760,19 @@ where let x = ctx.x(); let max = ctx.max(x, 1.0).unwrap(); - let shape = F::new(&ctx, max).unwrap(); + let shape = F::new(&ctx, &[max]).unwrap(); let tape = shape.interval_tape(Default::default()); let mut eval = F::new_interval_eval(); let (r, data) = eval.eval(&tape, &[[0.0, 2.0].into()]).unwrap(); - assert_eq!(r, [1.0, 2.0].into()); + assert_eq!(r[0], [1.0, 2.0].into()); assert!(data.is_none()); let (r, data) = eval.eval(&tape, &[[-1.0, 0.0].into()]).unwrap(); - assert_eq!(r, [1.0, 1.0].into()); + assert_eq!(r[0], [1.0, 1.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Right]); let (r, data) = eval.eval(&tape, &[[2.0, 3.0].into()]).unwrap(); - assert_eq!(r, [2.0, 3.0].into()); + assert_eq!(r[0], [2.0, 3.0].into()); assert_eq!(data.unwrap().as_ref(), &[Choice::Left]); } @@ -779,12 +782,12 @@ where let y = ctx.y(); let c = ctx.compare(x, y).unwrap(); - let shape = F::new(&ctx, c).unwrap(); + let shape = F::new(&ctx, &[c]).unwrap(); let tape = shape.interval_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_interval_eval(); let (out, _trace) = eval.eval(&tape, &vs(-5.0, -6.0)).unwrap(); - assert_eq!(out, Interval::from(1f32)); + assert_eq!(out[0], Interval::from(1f32)); } pub fn test_i_stress_n(depth: usize) { @@ -800,14 +803,14 @@ where let y: Vec<_> = x[1..].iter().chain(&x[0..1]).cloned().collect(); let z: Vec<_> = x[2..].iter().chain(&x[0..2]).cloned().collect(); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_interval_eval(); let tape = shape.interval_tape(Default::default()); let vs = bind_xyz(&tape); let mut out = vec![]; for i in 0..args.len() { - out.push(eval.eval(&tape, &vs(x[i], y[i], z[i])).unwrap().0); + out.push(eval.eval(&tape, &vs(x[i], y[i], z[i])).unwrap().0[0]); } // Compare against the VmShape evaluator as a baseline. It's possible @@ -856,12 +859,13 @@ where let v = ctx.var(Var::new()); let node = C::build(&mut ctx, v); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.interval_tape(Default::default()); assert_eq!(tape.vars().len(), 1); for &a in args.iter() { let (o, trace) = eval.eval(&tape, &[a]).unwrap(); + let o = o[0]; assert!(trace.is_none()); for i in 0..32 { @@ -954,7 +958,7 @@ where continue; } - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); let vars = tape.vars(); @@ -973,7 +977,7 @@ where Self::compare_interval_results( lhs, rhs, - out, + out[0], C::eval_reg_reg_f32, &name, ); @@ -990,7 +994,7 @@ where continue; } - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); let vars = tape.vars(); @@ -1005,7 +1009,7 @@ where Self::compare_interval_results( lhs, lhs, - out, + out[0], C::eval_reg_reg_f32, &name, ); @@ -1027,7 +1031,7 @@ where for &rhs in values.iter() { let node = C::build(&mut ctx, a, rhs); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); let (out, _trace) = eval.eval(&tape, &[lhs]).unwrap(); @@ -1036,7 +1040,7 @@ where Self::compare_interval_results( lhs, rhs.into(), - out, + out[0], C::eval_reg_imm_f32, &name, ); @@ -1058,7 +1062,7 @@ where for &rhs in args.iter() { let node = C::build(&mut ctx, lhs, va); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let tape = shape.interval_tape(tape_data.unwrap_or_default()); let (out, _trace) = eval.eval(&tape, &[rhs]).unwrap(); @@ -1067,7 +1071,7 @@ where Self::compare_interval_results( lhs.into(), rhs, - out, + out[0], C::eval_imm_reg_f32, &name, ); diff --git a/fidget/src/core/eval/test/point.rs b/fidget/src/core/eval/test/point.rs index 10db98d7..80b78624 100644 --- a/fidget/src/core/eval/test/point.rs +++ b/fidget/src/core/eval/test/point.rs @@ -26,20 +26,20 @@ where pub fn test_constant() { let mut ctx = Context::new(); let p = ctx.constant(1.5); - let shape = F::new(&ctx, p).unwrap(); + let shape = F::new(&ctx, &[p]).unwrap(); let tape = shape.point_tape(Default::default()); let mut eval = F::new_point_eval(); - assert_eq!(eval.eval(&tape, &[]).unwrap().0, 1.5); + assert_eq!(eval.eval(&tape, &[]).unwrap().0[0], 1.5); } pub fn test_constant_push() { let mut ctx = Context::new(); let min = ctx.min(1.5, Var::X).unwrap(); - let shape = F::new(&ctx, min).unwrap(); + let shape = F::new(&ctx, &[min]).unwrap(); let tape = shape.point_tape(Default::default()); let mut eval = F::new_point_eval(); let (r, trace) = eval.eval(&tape, &[2.0]).unwrap(); - assert_eq!(r, 1.5); + assert_eq!(r[0], 1.5); let next = shape .simplify( @@ -51,8 +51,8 @@ where assert_eq!(next.size(), 2); // constant, output let tape = next.point_tape(Default::default()); - assert_eq!(eval.eval(&tape, &[2.0]).unwrap().0, 1.5); - assert_eq!(eval.eval(&tape, &[1.0]).unwrap().0, 1.5); + assert_eq!(eval.eval(&tape, &[2.0]).unwrap().0[0], 1.5); + assert_eq!(eval.eval(&tape, &[1.0]).unwrap().0[0], 1.5); assert!(eval.eval(&tape, &[]).is_err()); } @@ -65,11 +65,11 @@ where let radius = ctx.add(x_squared, y_squared).unwrap(); let circle = ctx.sub(radius, 1.0).unwrap(); - let shape = F::new(&ctx, circle).unwrap(); + let shape = F::new(&ctx, &[circle]).unwrap(); let tape = shape.point_tape(Default::default()); let mut eval = F::new_point_eval(); - assert_eq!(eval.eval(&tape, &[0.0, 0.0]).unwrap().0, -1.0); - assert_eq!(eval.eval(&tape, &[1.0, 0.0]).unwrap().0, 0.0); + assert_eq!(eval.eval(&tape, &[0.0, 0.0]).unwrap().0[0], -1.0); + assert_eq!(eval.eval(&tape, &[1.0, 0.0]).unwrap().0[0], 0.0); } pub fn test_p_min() { @@ -78,29 +78,29 @@ where let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = F::new(&ctx, min).unwrap(); + let shape = F::new(&ctx, &[min]).unwrap(); let tape = shape.point_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_point_eval(); let (r, trace) = eval.eval(&tape, &vs(0.0, 0.0)).unwrap(); - assert_eq!(r, 0.0); + assert_eq!(r[0], 0.0); assert!(trace.is_none()); let (r, trace) = eval.eval(&tape, &vs(0.0, 1.0)).unwrap(); - assert_eq!(r, 0.0); + assert_eq!(r[0], 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs(2.0, 0.0)).unwrap(); - assert_eq!(r, 0.0); + assert_eq!(r[0], 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs(f32::NAN, 0.0)).unwrap(); - assert!(r.is_nan()); + assert!(r[0].is_nan()); assert!(trace.is_none()); let (r, trace) = eval.eval(&tape, &vs(0.0, f32::NAN)).unwrap(); - assert!(r.is_nan()); + assert!(r[0].is_nan()); assert!(trace.is_none()); } @@ -110,29 +110,29 @@ where let y = ctx.y(); let max = ctx.max(x, y).unwrap(); - let shape = F::new(&ctx, max).unwrap(); + let shape = F::new(&ctx, &[max]).unwrap(); let tape = shape.point_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_point_eval(); let (r, trace) = eval.eval(&tape, &vs(0.0, 0.0)).unwrap(); - assert_eq!(r, 0.0); + assert_eq!(r[0], 0.0); assert!(trace.is_none()); let (r, trace) = eval.eval(&tape, &vs(0.0, 1.0)).unwrap(); - assert_eq!(r, 1.0); + assert_eq!(r[0], 1.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs(2.0, 0.0)).unwrap(); - assert_eq!(r, 2.0); + assert_eq!(r[0], 2.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs(f32::NAN, 0.0)).unwrap(); - assert!(r.is_nan()); + assert!(r[0].is_nan()); assert!(trace.is_none()); let (r, trace) = eval.eval(&tape, &vs(0.0, f32::NAN)).unwrap(); - assert!(r.is_nan()); + assert!(r[0].is_nan()); assert!(trace.is_none()); } @@ -142,33 +142,33 @@ where let y = ctx.y(); let v = ctx.and(x, y).unwrap(); - let shape = F::new(&ctx, v).unwrap(); + let shape = F::new(&ctx, &[v]).unwrap(); let tape = shape.point_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_point_eval(); let (r, trace) = eval.eval(&tape, &vs(0.0, 0.0)).unwrap(); - assert_eq!(r, 0.0); + assert_eq!(r[0], 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs(0.0, 1.0)).unwrap(); - assert_eq!(r, 0.0); + assert_eq!(r[0], 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs(0.0, f32::NAN)).unwrap(); - assert_eq!(r, 0.0); + assert_eq!(r[0], 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs(0.1, 1.0)).unwrap(); - assert_eq!(r, 1.0); + assert_eq!(r[0], 1.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs(0.1, 0.0)).unwrap(); - assert_eq!(r, 0.0); + assert_eq!(r[0], 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs(f32::NAN, 1.2)).unwrap(); - assert_eq!(r, 1.2); + assert_eq!(r[0], 1.2); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); } @@ -178,33 +178,33 @@ where let y = ctx.y(); let v = ctx.or(x, y).unwrap(); - let shape = F::new(&ctx, v).unwrap(); + let shape = F::new(&ctx, &[v]).unwrap(); let tape = shape.point_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_point_eval(); let (r, trace) = eval.eval(&tape, &vs(0.0, 0.0)).unwrap(); - assert_eq!(r, 0.0); + assert_eq!(r[0], 0.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs(0.0, 1.0)).unwrap(); - assert_eq!(r, 1.0); + assert_eq!(r[0], 1.0); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs(0.0, f32::NAN)).unwrap(); - assert!(r.is_nan()); + assert!(r[0].is_nan()); assert_eq!(trace.unwrap().as_ref(), &[Choice::Right]); let (r, trace) = eval.eval(&tape, &vs(0.1, 1.0)).unwrap(); - assert_eq!(r, 0.1); + assert_eq!(r[0], 0.1); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs(0.1, 0.0)).unwrap(); - assert_eq!(r, 0.1); + assert_eq!(r[0], 0.1); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); let (r, trace) = eval.eval(&tape, &vs(f32::NAN, 1.2)).unwrap(); - assert!(r.is_nan()); + assert!(r[0].is_nan()); assert_eq!(trace.unwrap().as_ref(), &[Choice::Left]); } @@ -213,33 +213,33 @@ where let x = ctx.x(); let s = ctx.sin(x).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.point_tape(Default::default()); let mut eval = F::new_point_eval(); for x in [0.0, 1.0, 2.0] { let (r, trace) = eval.eval(&tape, &[x]).unwrap(); - assert_eq!(r, x.sin()); + assert_eq!(r[0], x.sin()); assert!(trace.is_none()); let (r, trace) = eval.eval(&tape, &[x]).unwrap(); - assert_eq!(r, x.sin()); + assert_eq!(r[0], x.sin()); assert!(trace.is_none()); let (r, trace) = eval.eval(&tape, &[x]).unwrap(); - assert_eq!(r, x.sin()); + assert_eq!(r[0], x.sin()); assert!(trace.is_none()); } let y = ctx.y(); let s = ctx.add(s, y).unwrap(); - let shape = F::new(&ctx, s).unwrap(); + let shape = F::new(&ctx, &[s]).unwrap(); let tape = shape.point_tape(Default::default()); let vs = bind_xy(&tape); for (x, y) in [(0.0, 1.0), (1.0, 3.0), (2.0, 8.0)] { let (r, trace) = eval.eval(&tape, &vs(x, y)).unwrap(); - assert_eq!(r, x.sin() + y); + assert_eq!(r[0], x.sin() + y); assert!(trace.is_none()); } } @@ -250,13 +250,13 @@ where let y = ctx.y(); let sum = ctx.add(x, 1.0).unwrap(); let min = ctx.min(sum, y).unwrap(); - let shape = F::new(&ctx, min).unwrap(); + let shape = F::new(&ctx, &[min]).unwrap(); let tape = shape.point_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_point_eval(); - assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, 2.0); - assert_eq!(eval.eval(&tape, &vs(1.0, 3.0)).unwrap().0, 2.0); - assert_eq!(eval.eval(&tape, &vs(3.0, 3.5)).unwrap().0, 3.5); + assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, [2.0]); + assert_eq!(eval.eval(&tape, &vs(1.0, 3.0)).unwrap().0, [2.0]); + assert_eq!(eval.eval(&tape, &vs(3.0, 3.5)).unwrap().0, [3.5]); } pub fn test_push() { @@ -265,12 +265,12 @@ where let y = ctx.y(); let min = ctx.min(x, y).unwrap(); - let shape = F::new(&ctx, min).unwrap(); + let shape = F::new(&ctx, &[min]).unwrap(); let tape = shape.point_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_point_eval(); - assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, 1.0); - assert_eq!(eval.eval(&tape, &vs(3.0, 2.0)).unwrap().0, 2.0); + assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, [1.0]); + assert_eq!(eval.eval(&tape, &vs(3.0, 2.0)).unwrap().0, [2.0]); let next = shape .simplify( @@ -281,8 +281,8 @@ where .unwrap(); let tape = next.point_tape(Default::default()); let vs = bind_xy(&tape); - assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, 1.0); - assert_eq!(eval.eval(&tape, &vs(3.0, 2.0)).unwrap().0, 3.0); + assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, [1.0]); + assert_eq!(eval.eval(&tape, &vs(3.0, 2.0)).unwrap().0, [3.0]); let next = shape .simplify( @@ -293,15 +293,15 @@ where .unwrap(); let tape = next.point_tape(Default::default()); let vs = bind_xy(&tape); - assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, 2.0); - assert_eq!(eval.eval(&tape, &vs(3.0, 2.0)).unwrap().0, 2.0); + assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, [2.0]); + assert_eq!(eval.eval(&tape, &vs(3.0, 2.0)).unwrap().0, [2.0]); let min = ctx.min(x, 1.0).unwrap(); - let shape = F::new(&ctx, min).unwrap(); + let shape = F::new(&ctx, &[min]).unwrap(); let tape = shape.point_tape(Default::default()); let mut eval = F::new_point_eval(); - assert_eq!(eval.eval(&tape, &[0.5]).unwrap().0, 0.5); - assert_eq!(eval.eval(&tape, &[3.0]).unwrap().0, 1.0); + assert_eq!(eval.eval(&tape, &[0.5]).unwrap().0, [0.5]); + assert_eq!(eval.eval(&tape, &[3.0]).unwrap().0, [1.0]); let next = shape .simplify( @@ -311,8 +311,8 @@ where ) .unwrap(); let tape = next.point_tape(Default::default()); - assert_eq!(eval.eval(&tape, &[0.5]).unwrap().0, 0.5); - assert_eq!(eval.eval(&tape, &[3.0]).unwrap().0, 3.0); + assert_eq!(eval.eval(&tape, &[0.5]).unwrap().0, [0.5]); + assert_eq!(eval.eval(&tape, &[3.0]).unwrap().0, [3.0]); let next = shape .simplify( @@ -322,8 +322,8 @@ where ) .unwrap(); let tape = next.point_tape(Default::default()); - assert_eq!(eval.eval(&tape, &[0.5]).unwrap().0, 1.0); - assert_eq!(eval.eval(&tape, &[3.0]).unwrap().0, 1.0); + assert_eq!(eval.eval(&tape, &[0.5]).unwrap().0, [1.0]); + assert_eq!(eval.eval(&tape, &[3.0]).unwrap().0, [1.0]); } pub fn test_basic() { @@ -331,26 +331,26 @@ where let x = ctx.x(); let y = ctx.y(); - let shape = F::new(&ctx, x).unwrap(); + let shape = F::new(&ctx, &[x]).unwrap(); let tape = shape.point_tape(Default::default()); let mut eval = F::new_point_eval(); - assert_eq!(eval.eval(&tape, &[1.0]).unwrap().0, 1.0); - assert_eq!(eval.eval(&tape, &[3.0]).unwrap().0, 3.0); + assert_eq!(eval.eval(&tape, &[1.0]).unwrap().0, [1.0]); + assert_eq!(eval.eval(&tape, &[3.0]).unwrap().0, [3.0]); - let shape = F::new(&ctx, y).unwrap(); + let shape = F::new(&ctx, &[y]).unwrap(); let tape = shape.point_tape(Default::default()); let mut eval = F::new_point_eval(); - assert_eq!(eval.eval(&tape, &[2.0]).unwrap().0, 2.0); - assert_eq!(eval.eval(&tape, &[4.0]).unwrap().0, 4.0); + assert_eq!(eval.eval(&tape, &[2.0]).unwrap().0, [2.0]); + assert_eq!(eval.eval(&tape, &[4.0]).unwrap().0, [4.0]); let y2 = ctx.mul(y, 2.5).unwrap(); let sum = ctx.add(x, y2).unwrap(); - let shape = F::new(&ctx, sum).unwrap(); + let shape = F::new(&ctx, &[sum]).unwrap(); let tape = shape.point_tape(Default::default()); let vs = bind_xy(&tape); let mut eval = F::new_point_eval(); - assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, 6.0); + assert_eq!(eval.eval(&tape, &vs(1.0, 2.0)).unwrap().0, [6.0]); } pub fn test_p_shape_var() { @@ -388,14 +388,14 @@ where let y: Vec<_> = x[1..].iter().chain(&x[0..1]).cloned().collect(); let z: Vec<_> = x[2..].iter().chain(&x[0..2]).cloned().collect(); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_point_eval(); let tape = shape.point_tape(Default::default()); let vs = bind_xyz(&tape); let mut out = vec![]; for i in 0..args.len() { - out.push(eval.eval(&tape, &vs(x[i], y[i], z[i])).unwrap().0); + out.push(eval.eval(&tape, &vs(x[i], y[i], z[i])).unwrap().0[0]); } for (i, v) in out.iter().cloned().enumerate() { @@ -445,7 +445,7 @@ where for v in [ctx.x(), ctx.y(), ctx.z(), ctx.var(Var::new())].into_iter() { let node = C::build(&mut ctx, v); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_point_eval(); let tape = shape.point_tape(Default::default()); @@ -453,6 +453,7 @@ where let (o, trace) = eval.eval(&tape, &[a]).unwrap(); assert!(trace.is_none()); let v = C::eval_f32(a); + let o = o[0]; let err = (v - o).abs(); assert!( (o == v) || err < 1e-6 || (v.is_nan() && o.is_nan()), @@ -494,7 +495,7 @@ where for &rhs in args.iter() { let node = C::build(&mut ctx, va, vb); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_point_eval(); let tape = shape.point_tape(Default::default()); let vars = tape.vars(); @@ -512,7 +513,7 @@ where Self::compare_point_results::( lhs, rhs, - out, + out[0], C::eval_reg_reg_f32, &name, ); @@ -522,7 +523,7 @@ where for &lhs in args.iter() { let node = C::build(&mut ctx, va, va); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_point_eval(); let tape = shape.point_tape(Default::default()); let vars = tape.vars(); @@ -536,7 +537,7 @@ where Self::compare_point_results::( lhs, lhs, - out, + out[0], C::eval_reg_reg_f32, &name, ); @@ -554,7 +555,7 @@ where for &rhs in args.iter() { let node = C::build(&mut ctx, va, rhs); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_point_eval(); let tape = shape.point_tape(Default::default()); @@ -563,7 +564,7 @@ where Self::compare_point_results::( lhs, rhs, - out, + out[0], C::eval_reg_imm_f32, &name, ); @@ -582,7 +583,7 @@ where for &rhs in args.iter() { let node = C::build(&mut ctx, lhs, va); - let shape = F::new(&ctx, node).unwrap(); + let shape = F::new(&ctx, &[node]).unwrap(); let mut eval = F::new_point_eval(); let tape = shape.point_tape(Default::default()); @@ -591,7 +592,7 @@ where Self::compare_point_results::( lhs, rhs, - out, + out[0], C::eval_imm_reg_f32, &name, ); @@ -604,6 +605,23 @@ where Self::test_binary_reg_imm::(); Self::test_binary_imm_reg::(); } + + pub fn test_multi_output() { + let mut ctx = Context::new(); + let x = ctx.x(); + let y = ctx.y(); + let a = ctx.min(x, y).unwrap(); + let b = ctx.max(x, y).unwrap(); + + let shape = F::new(&ctx, &[a, b]).unwrap(); + let mut eval = F::new_point_eval(); + let tape = shape.point_tape(Default::default()); + let vs = bind_xy(&tape); + + let (out, _trace) = eval.eval(&tape, &vs(1.0, 2.0)).unwrap(); + assert_eq!(out[0], 1.0); + assert_eq!(out[1], 2.0); + } } #[macro_export] @@ -632,6 +650,7 @@ macro_rules! point_tests { $crate::point_test!(test_basic, $t); $crate::point_test!(test_p_shape_var, $t); $crate::point_test!(test_p_stress, $t); + $crate::point_test!(test_multi_output, $t); mod p_unary { use super::*; diff --git a/fidget/src/core/eval/test/symbolic_deriv.rs b/fidget/src/core/eval/test/symbolic_deriv.rs index ca4da1dc..cbffe960 100644 --- a/fidget/src/core/eval/test/symbolic_deriv.rs +++ b/fidget/src/core/eval/test/symbolic_deriv.rs @@ -17,12 +17,12 @@ impl TestSymbolicDerivs { let mut ctx = Context::new(); let v = ctx.var(Var::new()); let node = C::build(&mut ctx, v); - let shape = VmFunction::new(&ctx, node).unwrap(); + let shape = VmFunction::new(&ctx, &[node]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let mut eval = VmFunction::new_grad_slice_eval(); let node_deriv = ctx.deriv(node, ctx.get_var(v).unwrap()).unwrap(); - let shape_deriv = VmFunction::new(&ctx, node_deriv).unwrap(); + let shape_deriv = VmFunction::new(&ctx, &[node_deriv]).unwrap(); let tape_deriv = shape_deriv.float_slice_tape(Default::default()); let mut eval_deriv = VmFunction::new_float_slice_eval(); @@ -64,15 +64,15 @@ impl TestSymbolicDerivs { let mut eval_deriv = VmFunction::new_float_slice_eval(); let node = C::build(&mut ctx, a, b); - let shape = VmFunction::new(&ctx, node).unwrap(); + let shape = VmFunction::new(&ctx, &[node]).unwrap(); let tape = shape.grad_slice_tape(Default::default()); let node_a_deriv = ctx.deriv(node, va).unwrap(); - let shape_a_deriv = VmFunction::new(&ctx, node_a_deriv).unwrap(); + let shape_a_deriv = VmFunction::new(&ctx, &[node_a_deriv]).unwrap(); let tape_a_deriv = shape_a_deriv.float_slice_tape(Default::default()); let node_b_deriv = ctx.deriv(node, vb).unwrap(); - let shape_b_deriv = VmFunction::new(&ctx, node_b_deriv).unwrap(); + let shape_b_deriv = VmFunction::new(&ctx, &[node_b_deriv]).unwrap(); let tape_b_deriv = shape_b_deriv.float_slice_tape(Default::default()); for rot in 0..args.len() { diff --git a/fidget/src/core/eval/tracing.rs b/fidget/src/core/eval/tracing.rs index 0aae54fd..c5cb2abb 100644 --- a/fidget/src/core/eval/tracing.rs +++ b/fidget/src/core/eval/tracing.rs @@ -47,7 +47,7 @@ pub trait TracingEvaluator: Default { &mut self, tape: &Self::Tape, vars: &[Self::Data], - ) -> Result<(Self::Data, Option<&Self::Trace>), Error>; + ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error>; /// Build a new empty evaluator fn new() -> Self { diff --git a/fidget/src/core/shape/mod.rs b/fidget/src/core/shape/mod.rs index c2352ad8..515ee58e 100644 --- a/fidget/src/core/shape/mod.rs +++ b/fidget/src/core/shape/mod.rs @@ -309,7 +309,7 @@ impl Shape { node: Node, axes: [Var; 3], ) -> Result { - let f = F::new(ctx, node)?; + let f = F::new(ctx, &[node])?; Ok(Self { f, axes, @@ -409,6 +409,12 @@ where z: F, vars: &HashMap, ) -> Result<(E::Data, Option<&E::Trace>), Error> { + assert_eq!( + tape.tape.output_count(), + 1, + "ShapeTape has multiple outputs" + ); + let x = x.into(); let y = y.into(); let z = z.into(); @@ -449,7 +455,8 @@ where } } - self.eval.eval(&tape.tape, &self.scratch) + let (out, trace) = self.eval.eval(&tape.tape, &self.scratch)?; + Ok((out[0], trace)) } } @@ -496,6 +503,12 @@ where z: &[E::Data], vars: &HashMap, ) -> Result<&[E::Data], Error> { + assert_eq!( + tape.tape.output_count(), + 1, + "ShapeTape has multiple outputs" + ); + // Make sure our scratch arrays are big enough for this evaluation if x.len() != y.len() || x.len() != z.len() { return Err(Error::MismatchedSlices); diff --git a/fidget/src/core/vm/data.rs b/fidget/src/core/vm/data.rs index 7f207dd5..acb8d8e1 100644 --- a/fidget/src/core/vm/data.rs +++ b/fidget/src/core/vm/data.rs @@ -48,7 +48,7 @@ use std::sync::Arc; /// let tree = Tree::x() + Tree::y(); /// let mut ctx = Context::new(); /// let sum = ctx.import(&tree); -/// let data = VmData::<255>::new(&ctx, sum)?; +/// let data = VmData::<255>::new(&ctx, &[sum])?; /// assert_eq!(data.len(), 4); // X, Y, (X + Y), and output /// /// let mut iter = data.iter_asm(); @@ -76,8 +76,8 @@ pub struct VmData { impl VmData { /// Builds a new tape for the given node - pub fn new(context: &Context, node: Node) -> Result { - let (ssa, vars) = SsaTape::new(context, node)?; + pub fn new(context: &Context, nodes: &[Node]) -> Result { + let (ssa, vars) = SsaTape::new(context, nodes)?; let asm = RegTape::new::(&ssa); Ok(Self { ssa, @@ -104,6 +104,14 @@ impl VmData { self.ssa.choice_count } + /// Returns the number of output nodes in the tape. + /// + /// This is required because some evaluators pre-allocate spaces for the + /// output array. + pub fn output_count(&self) -> usize { + self.ssa.output_count + } + /// Returns the number of slots used by the inner VM tape pub fn slot_count(&self) -> usize { self.asm.slot_count() @@ -131,6 +139,7 @@ impl VmData { workspace.reset(self.ssa.tape.len(), tape.asm); let mut choice_count = 0; + let mut output_count = 0; // Other iterators to consume various arrays in order let mut choice_iter = choices.iter().rev(); @@ -143,6 +152,7 @@ impl VmData { *reg = workspace.get_or_insert_active(*reg); workspace.alloc.op(op); ops_out.push(op); + output_count += 1; continue; } _ => op.output().unwrap(), @@ -297,6 +307,7 @@ impl VmData { ssa: SsaTape { tape: ops_out, choice_count, + output_count, }, asm: asm_tape, vars: self.vars.clone(), diff --git a/fidget/src/core/vm/mod.rs b/fidget/src/core/vm/mod.rs index 7e17e488..9e963b6c 100644 --- a/fidget/src/core/vm/mod.rs +++ b/fidget/src/core/vm/mod.rs @@ -46,6 +46,10 @@ impl Tape for GenericVmFunction { fn vars(&self) -> &VarMap { &self.0.vars } + + fn output_count(&self) -> usize { + self.0.output_count() + } } /// A trace captured by a VM evaluation @@ -194,8 +198,8 @@ impl RenderHints for GenericVmFunction { } impl MathFunction for GenericVmFunction { - fn new(ctx: &Context, node: Node) -> Result { - let d = VmData::new(ctx, node)?; + fn new(ctx: &Context, nodes: &[Node]) -> Result { + let d = VmData::new(ctx, nodes)?; Ok(Self(d.into())) } } @@ -232,13 +236,15 @@ impl std::ops::IndexMut for SlotArray<'_, T> { /// Generic VM evaluator for tracing evaluation struct TracingVmEval { slots: Vec, + out: Vec, choices: VmTrace, } impl Default for TracingVmEval { fn default() -> Self { Self { - slots: vec![], + slots: Vec::default(), + out: Vec::default(), choices: VmTrace::default(), } } @@ -248,6 +254,7 @@ impl + Clone> TracingVmEval { fn resize_slots(&mut self, tape: &VmData) { self.slots.resize(tape.slot_count(), f32::NAN.into()); self.choices.resize(tape.choice_count(), Choice::Unknown); + self.out.resize(tape.output_count(), f32::NAN.into()); self.choices.fill(Choice::Unknown); } } @@ -265,7 +272,7 @@ impl TracingEvaluator for VmIntervalEval { &mut self, tape: &Self::Tape, vars: &[Interval], - ) -> Result<(Interval, Option<&VmTrace>), Error> { + ) -> Result<(&[Interval], Option<&VmTrace>), Error> { tape.vars().check_tracing_arguments(vars)?; let tape = tape.0.as_ref(); self.0.resize_slots(tape); @@ -273,13 +280,10 @@ impl TracingEvaluator for VmIntervalEval { let mut simplify = false; let mut v = SlotArray(&mut self.0.slots); let mut choices = self.0.choices.as_mut_slice().iter_mut(); - let mut out = None; for op in tape.iter_asm() { match op { RegOp::Output(arg, i) => { - assert_eq!(i, 0); - assert!(out.is_none()); - out = Some(v[arg]); + self.0.out[i as usize] = v[arg]; } RegOp::Input(out, i) => { v[out] = vars[i as usize]; @@ -477,7 +481,7 @@ impl TracingEvaluator for VmIntervalEval { } } Ok(( - out.unwrap(), + &self.0.out, if simplify { Some(&self.0.choices) } else { @@ -500,7 +504,7 @@ impl TracingEvaluator for VmPointEval { &mut self, tape: &Self::Tape, vars: &[f32], - ) -> Result<(f32, Option<&VmTrace>), Error> { + ) -> Result<(&[f32], Option<&VmTrace>), Error> { tape.vars().check_tracing_arguments(vars)?; let tape = tape.0.as_ref(); self.0.resize_slots(tape); @@ -508,13 +512,10 @@ impl TracingEvaluator for VmPointEval { let mut choices = self.0.choices.as_mut_slice().iter_mut(); let mut simplify = false; let mut v = SlotArray(&mut self.0.slots); - let mut out = None; for op in tape.iter_asm() { match op { RegOp::Output(arg, i) => { - assert_eq!(i, 0); - assert!(out.is_none()); - out = Some(v[arg]); + self.0.out[i as usize] = v[arg]; } RegOp::Input(out, i) => { v[out] = vars[i as usize]; @@ -778,7 +779,7 @@ impl TracingEvaluator for VmPointEval { } } Ok(( - out.unwrap(), + &self.0.out, if simplify { Some(&self.0.choices) } else { diff --git a/fidget/src/jit/mod.rs b/fidget/src/jit/mod.rs index a948a590..75b12cfd 100644 --- a/fidget/src/jit/mod.rs +++ b/fidget/src/jit/mod.rs @@ -849,6 +849,7 @@ impl JitFunction { mmap: f, vars: self.0.data().vars.clone(), choice_count: self.0.choice_count(), + output_count: self.0.output_count(), fn_trace: unsafe { std::mem::transmute(ptr) }, } } @@ -857,6 +858,7 @@ impl JitFunction { let ptr = f.as_ptr(); JitBulkFn { mmap: f, + output_count: self.0.output_count(), vars: self.0.data().vars.clone(), fn_bulk: unsafe { std::mem::transmute(ptr) }, } @@ -957,9 +959,18 @@ macro_rules! jit_fn { /// /// Users are unlikely to use this directly, but it's public because it's an /// associated type on [`JitFunction`]. -#[derive(Default)] -struct JitTracingEval { +struct JitTracingEval { choices: VmTrace, + out: Vec, +} + +impl Default for JitTracingEval { + fn default() -> Self { + Self { + choices: VmTrace::default(), + out: Vec::default(), + } + } } /// Handle to an owned function pointer for tracing evaluation @@ -967,6 +978,7 @@ pub struct JitTracingFn { #[allow(unused)] mmap: Mmap, choice_count: usize, + output_count: usize, vars: Arc, fn_trace: jit_fn!( unsafe fn( @@ -987,6 +999,10 @@ impl Tape for JitTracingFn { fn vars(&self) -> &VarMap { &self.vars } + + fn output_count(&self) -> usize { + self.output_count + } } // SAFETY: there is no mutable state in a `JitTracingFn`, and the pointer @@ -994,28 +1010,29 @@ impl Tape for JitTracingFn { unsafe impl Send for JitTracingFn {} unsafe impl Sync for JitTracingFn {} -impl JitTracingEval { +impl + Clone> JitTracingEval { /// Evaluates a single point, capturing an evaluation trace - fn eval>( + fn eval( &mut self, tape: &JitTracingFn, vars: &[T], - ) -> (T, Option<&VmTrace>) { + ) -> (&[T], Option<&VmTrace>) { let mut simplify = 0; self.choices.resize(tape.choice_count, Choice::Unknown); self.choices.fill(Choice::Unknown); - let mut out = f32::NAN.into(); + self.out.resize(tape.output_count, std::f32::NAN.into()); + self.out.fill(f32::NAN.into()); unsafe { (tape.fn_trace)( vars.as_ptr(), self.choices.as_mut_ptr() as *mut u8, &mut simplify, - &mut out, + self.out.as_mut_ptr(), ) }; ( - out, + &self.out, if simplify != 0 { Some(&self.choices) } else { @@ -1027,7 +1044,7 @@ impl JitTracingEval { /// JIT-based tracing evaluator for interval values #[derive(Default)] -pub struct JitIntervalEval(JitTracingEval); +pub struct JitIntervalEval(JitTracingEval); impl TracingEvaluator for JitIntervalEval { type Data = Interval; type Tape = JitTracingFn; @@ -1038,7 +1055,7 @@ impl TracingEvaluator for JitIntervalEval { &mut self, tape: &Self::Tape, vars: &[Self::Data], - ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { + ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> { tape.vars().check_tracing_arguments(vars)?; Ok(self.0.eval(tape, vars)) } @@ -1046,7 +1063,7 @@ impl TracingEvaluator for JitIntervalEval { /// JIT-based tracing evaluator for point values #[derive(Default)] -pub struct JitPointEval(JitTracingEval); +pub struct JitPointEval(JitTracingEval); impl TracingEvaluator for JitPointEval { type Data = f32; type Tape = JitTracingFn; @@ -1057,7 +1074,7 @@ impl TracingEvaluator for JitPointEval { &mut self, tape: &Self::Tape, vars: &[Self::Data], - ) -> Result<(Self::Data, Option<&Self::Trace>), Error> { + ) -> Result<(&[Self::Data], Option<&Self::Trace>), Error> { tape.vars().check_tracing_arguments(vars)?; Ok(self.0.eval(tape, vars)) } @@ -1070,6 +1087,7 @@ pub struct JitBulkFn { #[allow(unused)] mmap: Mmap, vars: Arc, + output_count: usize, fn_bulk: jit_fn!( unsafe fn( *const *const T, // vars @@ -1088,6 +1106,10 @@ impl Tape for JitBulkFn { fn vars(&self) -> &VarMap { &self.vars } + + fn output_count(&self) -> usize { + self.output_count + } } /// Maximum SIMD width for any type, checked at runtime (alas) @@ -1259,8 +1281,8 @@ impl BulkEvaluator for JitGradSliceEval { } impl MathFunction for JitFunction { - fn new(ctx: &Context, node: Node) -> Result { - GenericVmFunction::new(ctx, node).map(JitFunction) + fn new(ctx: &Context, nodes: &[Node]) -> Result { + GenericVmFunction::new(ctx, nodes).map(JitFunction) } } diff --git a/fidget/src/solver/mod.rs b/fidget/src/solver/mod.rs index 9818649f..968f84e3 100644 --- a/fidget/src/solver/mod.rs +++ b/fidget/src/solver/mod.rs @@ -168,7 +168,7 @@ impl<'a, F: Function> Solver<'a, F> { } // Do the actual gradient evaluation let (out, _t) = self.point_eval.eval(tape, &self.input_point)?; - err += out.powi(2); + err += out[0].powi(2); // TODO: consolidate into a single tape } Ok(err) } @@ -302,7 +302,7 @@ mod test { let mut ctx = Context::new(); let root = ctx.import(&eqn); - let f = VmFunction::new(&ctx, root).unwrap(); + let f = VmFunction::new(&ctx, &[root]).unwrap(); let mut values = HashMap::new(); values.insert(Var::X, Parameter::Free(0.0)); values.insert(Var::Y, Parameter::Fixed(-1.0)); @@ -321,7 +321,7 @@ mod test { let mut ctx = Context::new(); let root = ctx.import(&root); - let f = VmFunction::new(&ctx, root).unwrap(); + let f = VmFunction::new(&ctx, &[root]).unwrap(); let mut values = HashMap::new(); for (i, &v) in vs.iter().enumerate() { values.insert(v, Parameter::Free(i as f32)); @@ -343,7 +343,7 @@ mod test { for (i, &v) in vs.iter().enumerate() { let eqn = Tree::from(v) - Tree::from(i as f32); let root = ctx.import(&eqn); - let f = VmFunction::new(&ctx, root).unwrap(); + let f = VmFunction::new(&ctx, &[root]).unwrap(); eqns.push(f); } @@ -369,7 +369,7 @@ mod test { .into_iter() .map(|c| { let root = ctx.import(&c); - VmFunction::new(&ctx, root).unwrap() + VmFunction::new(&ctx, &[root]).unwrap() }) .collect::>(); @@ -395,7 +395,7 @@ mod test { .into_iter() .map(|c| { let root = ctx.import(&c); - VmFunction::new(&ctx, root).unwrap() + VmFunction::new(&ctx, &[root]).unwrap() }) .collect::>(); @@ -420,7 +420,7 @@ mod test { .into_iter() .map(|c| { let root = ctx.import(&c); - VmFunction::new(&ctx, root).unwrap() + VmFunction::new(&ctx, &[root]).unwrap() }) .collect::>(); @@ -444,7 +444,7 @@ mod test { let t = (Tree::x().square() + Tree::y().square()).sqrt(); let mut ctx = Context::new(); let root = ctx.import(&t); - let eqn = VmFunction::new(&ctx, root).unwrap(); + let eqn = VmFunction::new(&ctx, &[root]).unwrap(); let eqns = [eqn]; let mut values = HashMap::new(); @@ -487,7 +487,7 @@ mod test { out += *mat.get((row, col)).unwrap() * t.clone(); } let root = ctx.import(&out); - let f = VmFunction::new(&ctx, root).unwrap(); + let f = VmFunction::new(&ctx, &[root]).unwrap(); eqns.push(f); } @@ -573,7 +573,7 @@ mod test { } } let root = ctx.import(&out); - let f = VmFunction::new(&ctx, root).unwrap(); + let f = VmFunction::new(&ctx, &[root]).unwrap(); eqns.push(f); }