Skip to content

Commit

Permalink
API plumbing for multi-output tapes (#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkeeter committed Aug 30, 2024
1 parent ff8d255 commit 7b6c16a
Show file tree
Hide file tree
Showing 17 changed files with 429 additions and 333 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# 0.3.3
(nothing here yet)
- `Function` and evaluator types now produce multiple outputs
- `MathFunction::new` now takes a slice of nodes, instead of a single node
- All of the intermediate tape formats (`SsaTape`, etc) are aware of
multiple output nodes
- Evaluation now returns a slice of outputs, one for each root node (ordered
based on order in the `&[Node]` slice passed to `MathFunction::new`)

# 0.3.2
- Added `impl IntoNode for Var`, to make handling `Var` values in a context
Expand Down
7 changes: 4 additions & 3 deletions demos/constraints/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl ConstraintsApp {
.into_iter()
.map(|eqn| {
let root = ctx.import(&eqn);
fidget::vm::VmFunction::new(&ctx, root).unwrap()
fidget::vm::VmFunction::new(&ctx, &[root]).unwrap()
})
.collect::<Vec<_>>();

Expand Down Expand Up @@ -180,8 +180,9 @@ impl eframe::App for ConstraintsApp {
if *dragged {
let v = ctx.var(var);
let weight = ctx.sub(v, p).unwrap();
let f = fidget::vm::VmFunction::new(&ctx, weight)
.unwrap();
let f =
fidget::vm::VmFunction::new(&ctx, &[weight])
.unwrap();
constraints.push(f);
}
}
Expand Down
37 changes: 25 additions & 12 deletions fidget/src/core/compiler/ssa_tape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,17 @@ pub struct SsaTape {

/// Number of choice operations in the tape
pub choice_count: usize,

/// Number of output operations in the tape
pub output_count: usize,
}

impl SsaTape {
/// Flattens a subtree of the graph into straight-line code.
///
/// This should always succeed unless the `root` is from a different
/// `Context`, in which case `Error::BadNode` will be returned.
pub fn new(ctx: &Context, root: Node) -> Result<(Self, VarMap), Error> {
pub fn new(ctx: &Context, roots: &[Node]) -> Result<(Self, VarMap), Error> {
let mut mapping = HashMap::new();
let mut parent_count: HashMap<Node, usize> = HashMap::new();
let mut slot_count = 0;
Expand All @@ -48,7 +51,7 @@ impl SsaTape {
// Accumulate parent counts and declare all nodes
let mut seen = HashSet::new();
let mut vars = VarMap::new();
let mut todo = vec![root];
let mut todo = roots.to_vec();
while let Some(node) = todo.pop() {
if !seen.insert(node) {
continue;
Expand Down Expand Up @@ -76,15 +79,18 @@ impl SsaTape {

// Now that we've populated our parents, flatten the graph
let mut seen = HashSet::new();
let mut todo = vec![root];
let mut todo = roots.to_vec();
let mut choice_count = 0;

let mut tape = vec![];
match mapping[&root] {
Slot::Reg(out_reg) => tape.push(SsaOp::Output(out_reg, 0)),
Slot::Immediate(imm) => {
tape.push(SsaOp::Output(0, 0));
tape.push(SsaOp::CopyImm(0, imm));
for (i, r) in roots.iter().enumerate() {
let i = i as u32;
match mapping[r] {
Slot::Reg(out_reg) => tape.push(SsaOp::Output(out_reg, i)),
Slot::Immediate(imm) => {
tape.push(SsaOp::Output(0, i));
tape.push(SsaOp::CopyImm(0, imm));
}
}
}

Expand Down Expand Up @@ -235,7 +241,14 @@ impl SsaTape {
tape.push(op);
}

Ok((SsaTape { tape, choice_count }, vars))
Ok((
SsaTape {
tape,
choice_count,
output_count: roots.len(),
},
vars,
))
}

/// Checks whether the tape is empty
Expand Down Expand Up @@ -410,7 +423,7 @@ mod test {
let c8 = ctx.sub(c7, r).unwrap();
let c9 = ctx.max(c8, c6).unwrap();

let (tape, vs) = SsaTape::new(&ctx, c9).unwrap();
let (tape, vs) = SsaTape::new(&ctx, &[c9]).unwrap();
assert_eq!(tape.len(), 9);
assert_eq!(vs.len(), 2);
}
Expand All @@ -421,7 +434,7 @@ mod test {
let x = ctx.x();
let x_squared = ctx.mul(x, x).unwrap();

let (tape, vs) = SsaTape::new(&ctx, x_squared).unwrap();
let (tape, vs) = SsaTape::new(&ctx, &[x_squared]).unwrap();
assert_eq!(tape.len(), 3); // x, square, output
assert_eq!(vs.len(), 1);
}
Expand All @@ -430,7 +443,7 @@ mod test {
fn test_constant() {
let mut ctx = Context::new();
let p = ctx.constant(1.5);
let (tape, vs) = SsaTape::new(&ctx, p).unwrap();
let (tape, vs) = SsaTape::new(&ctx, &[p]).unwrap();
assert_eq!(tape.len(), 2); // CopyImm, output
assert_eq!(vs.len(), 0);
}
Expand Down
4 changes: 2 additions & 2 deletions fidget/src/core/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,7 @@ mod test {
let c8 = ctx.sub(c7, r).unwrap();
let c9 = ctx.max(c8, c6).unwrap();

let tape = VmData::<255>::new(&ctx, c9).unwrap();
let tape = VmData::<255>::new(&ctx, &[c9]).unwrap();
assert_eq!(tape.len(), 9);
assert_eq!(tape.vars.len(), 2);
}
Expand All @@ -1577,7 +1577,7 @@ mod test {
let x = ctx.x();
let x_squared = ctx.mul(x, x).unwrap();

let tape = VmData::<255>::new(&ctx, x_squared).unwrap();
let tape = VmData::<255>::new(&ctx, &[x_squared]).unwrap();
assert_eq!(tape.len(), 3); // x, square, output
assert_eq!(tape.vars.len(), 1);
}
Expand Down
2 changes: 1 addition & 1 deletion fidget/src/core/eval/bulk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub trait BulkEvaluator: Default {
/// Container for bulk output results
///
/// This container represents an array-of-arrays. It is indexed first by
/// variable, then by index within the evaluation array.
/// output index, then by index within the evaluation array.
pub struct BulkOutput<'a, T> {
data: &'a Vec<Vec<T>>,
len: usize,
Expand Down
9 changes: 8 additions & 1 deletion fidget/src/core/eval/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ pub trait Tape {

/// Returns a mapping from [`Var`](crate::var::Var) to evaluation index
fn vars(&self) -> &VarMap;

/// Returns the number of outputs written by this tape
///
/// The order of outputs is set by the caller at tape construction, so we
/// don't need a map to determine the index of a particular output (unlike
/// variables).
fn output_count(&self) -> usize;
}

/// Represents the trace captured by a tracing evaluation
Expand Down Expand Up @@ -175,7 +182,7 @@ pub trait Function: Send + Sync + Clone {
/// A [`Function`] which can be built from a math expression
pub trait MathFunction: Function {
/// Builds a new function from the given context and node
fn new(ctx: &Context, node: Node) -> Result<Self, Error>
fn new(ctx: &Context, nodes: &[Node]) -> Result<Self, Error>
where
Self: Sized;
}
20 changes: 10 additions & 10 deletions fidget/src/core/eval/test/float_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
let x = ctx.x();
let x1 = ctx.add(x, 1.0).unwrap();

let shape_x = F::new(&ctx, x).unwrap();
let shape_x1 = F::new(&ctx, x1).unwrap();
let shape_x = F::new(&ctx, &[x]).unwrap();
let shape_x1 = F::new(&ctx, &[x1]).unwrap();

// This is a fuzz test for icache issues
let mut eval = F::new_float_slice_eval();
Expand Down Expand Up @@ -53,7 +53,7 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
let y = ctx.y();

let mut eval = F::new_float_slice_eval();
let shape = F::new(&ctx, x).unwrap();
let shape = F::new(&ctx, &[x]).unwrap();
let tape = shape.float_slice_tape(Default::default());
let out = eval
.eval(&tape, &[[0.0, 1.0, 2.0, 3.0].as_slice()])
Expand All @@ -77,7 +77,7 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
assert_eq!(&out[0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

let mul = ctx.mul(y, 2.0).unwrap();
let shape = F::new(&ctx, mul).unwrap();
let shape = F::new(&ctx, &[mul]).unwrap();
let tape = shape.float_slice_tape(Default::default());
let out = eval
.eval(&tape, &[[3.0, 2.0, 1.0, 0.0].as_slice()])
Expand All @@ -98,7 +98,7 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
let a = ctx.x();
let b = ctx.sin(a).unwrap();

let shape = F::new(&ctx, b).unwrap();
let shape = F::new(&ctx, &[b]).unwrap();
let mut eval = F::new_float_slice_eval();
let tape = shape.float_slice_tape(Default::default());

Expand Down Expand Up @@ -164,7 +164,7 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
let z: Vec<f32> =
args[2..].iter().chain(&args[0..2]).cloned().collect();

let shape = F::new(&ctx, node).unwrap();
let shape = F::new(&ctx, &[node]).unwrap();
let mut eval = F::new_float_slice_eval();
let tape = shape.float_slice_tape(Default::default());

Expand Down Expand Up @@ -222,7 +222,7 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {

let node = C::build(&mut ctx, v);

let shape = F::new(&ctx, node).unwrap();
let shape = F::new(&ctx, &[node]).unwrap();
let mut eval = F::new_float_slice_eval();
let tape = shape.float_slice_tape(Default::default());

Expand Down Expand Up @@ -271,7 +271,7 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
rgsa.rotate_left(rot);
let node = C::build(&mut ctx, va, vb);

let shape = F::new(&ctx, node).unwrap();
let shape = F::new(&ctx, &[node]).unwrap();
let mut eval = F::new_float_slice_eval();
let tape = shape.float_slice_tape(Default::default());
let vars = tape.vars();
Expand Down Expand Up @@ -309,7 +309,7 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
for rhs in args.iter() {
let node = C::build(&mut ctx, va, *rhs);

let shape = F::new(&ctx, node).unwrap();
let shape = F::new(&ctx, &[node]).unwrap();
let mut eval = F::new_float_slice_eval();
let tape = shape.float_slice_tape(Default::default());

Expand Down Expand Up @@ -340,7 +340,7 @@ impl<F: Function + MathFunction> TestFloatSlice<F> {
for lhs in args.iter() {
let node = C::build(&mut ctx, *lhs, va);

let shape = F::new(&ctx, node).unwrap();
let shape = F::new(&ctx, &[node]).unwrap();
let mut eval = F::new_float_slice_eval();
let tape = shape.float_slice_tape(Default::default());

Expand Down
Loading

0 comments on commit 7b6c16a

Please sign in to comment.