Skip to content

Commit

Permalink
[pyschlandals] add to_graphviz + copy probability of computed circuit…
Browse files Browse the repository at this point in the history
… at PyDac creation
  • Loading branch information
AlexandreDubray committed Aug 11, 2023
1 parent f8d28da commit c0ed7cf
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 25 deletions.
107 changes: 82 additions & 25 deletions pyschlandals/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use pyo3::prelude::*;
use pyo3::Python;
use std::path::PathBuf;
use schlandals::*;
use std::fs::File;
use std::io::Write;

#[pyclass]
#[derive(Clone)]
Expand Down Expand Up @@ -58,6 +60,7 @@ fn approximate_search_function(file: String, branching: BranchingHeuristic, epsi
fn compilation_submodule(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> {
let module = PyModule::new(py, "compiler")?;
module.add_function(wrap_pyfunction!(compile_function, module)?)?;
module.add_function(wrap_pyfunction!(compiler_from_file, module)?)?;
module.add_class::<PyDac>()?;
module.add_class::<PyCircuitNode>()?;
module.add_class::<PyDistributionNode>()?;
Expand Down Expand Up @@ -194,6 +197,78 @@ impl PyDac {
pub fn circuit_node_input_distribution_at(&self, node: usize, index: usize) -> (usize, usize) {
self.nodes[node].distribution_input[index]
}

pub fn to_graphviz(&self, path: String) {

let mut out = String::new();
out.push_str("digraph {\ntranksep = 3;\n\n");
let node_attributes = String::from("shape=circle,style=filled");

for node in 0..self.distributions.len() {
if !self.distributions[node].outputs.is_empty() {
out.push_str(&format!("{} [{},label=\"{}\"];\n", node, &node_attributes, &format!("d{}", node)));
}
}

for node in 0..self.nodes.len() {
let id = node + self.distributions.len();
if self.nodes[node].is_mul {
out.push_str(&format!("{} [{}, label=\"(X ({:.3}))\"];\n", id, &node_attributes, self.nodes[node].value))
} else {
out.push_str(&format!("{} [{}, label=\"(+ ({:.3}))\"];\n", id, &node_attributes, self.nodes[node].value))
}
}

for node in 0..self.distributions.len() {
for (output, value) in self.distributions[node].outputs.iter().copied() {
let to = output + self.distributions.len();
let label = format!("({}, {:.3})", value, self.distributions[node].probabilities[value]);
out.push_str(&format!("\t{node} -> {to} [penwidth=1,label=\"{label}\"];\n"));
}
}

for node in 0..self.nodes.len() {
let from = node + self.distributions.len();
let start = self.nodes[node].outputs_start;
let end = start + self.nodes[node].number_output;
for output in self.outputs[start..end].iter().copied() {
let to = output + self.distributions.len();
out.push_str(&format!("\t{from} -> {to} [penwidth=1];\n"));
}
}
out.push_str("}\n");
let mut outfile = File::create(path).unwrap();
match outfile.write(out.as_bytes()) {
Ok(_) => (),
Err(e) => println!("Culd not write the circuit into the dot file: {:?}", e),
}
}
}

fn pydac_from_dac(dac: Dac) -> PyDac {
let mut py_dac = PyDac::new();
py_dac.outputs = dac.outputs_node_iter().map(|n| n.0).collect();
py_dac.inputs = dac.inputs_node_iter().map(|n| n.0).collect();
for distribution in dac.distributions_iter() {
py_dac.distributions.push(PyDistributionNode {
probabilities: dac.get_distribution_probabilities(distribution).to_vec(),
outputs: dac.get_distribution_outputs(distribution).iter().copied().map(|(node, value)| (node.0, value)).collect(),
})
}

for node in dac.nodes_iter() {
let is_mul = dac.is_circuit_node_mul(node);
py_dac.nodes.push(PyCircuitNode {
outputs_start: dac.get_circuit_node_out_start(node),
number_output: dac.get_circuit_node_number_output(node),
inputs_start: dac.get_circuit_node_in_start(node),
number_input: dac.get_circuit_node_number_input(node),
distribution_input: dac.get_circuit_node_input_distribution(node).map(|(n, v)| (n.0, v)).collect(),
value: dac.get_circuit_node_probability(node).to_f64(),
is_mul,
})
}
py_dac
}

#[pyfunction]
Expand All @@ -206,34 +281,16 @@ fn compile_function(file: String, branching: BranchingHeuristic) -> Option<PyDac
};
match compile(PathBuf::from(file), branching_heuristic, None, None) {
None => None,
Some(dac) => {
let mut py_dac = PyDac::new();
py_dac.outputs = dac.outputs_node_iter().map(|n| n.0).collect();
py_dac.inputs = dac.inputs_node_iter().map(|n| n.0).collect();
for distribution in dac.distributions_iter() {
py_dac.distributions.push(PyDistributionNode {
probabilities: dac.get_distribution_probabilities(distribution).to_vec(),
outputs: dac.get_distribution_outputs(distribution).iter().copied().map(|(node, value)| (node.0, value)).collect(),
})
}

for node in dac.nodes_iter() {
let is_mul = dac.is_circuit_node_mul(node);
py_dac.nodes.push(PyCircuitNode {
outputs_start: dac.get_circuit_node_out_start(node),
number_output: dac.get_circuit_node_number_output(node),
inputs_start: dac.get_circuit_node_in_start(node),
number_input: dac.get_circuit_node_number_input(node),
distribution_input: dac.get_circuit_node_input_distribution(node).map(|(n, v)| (n.0, v)).collect(),
value: if is_mul { 1.0 } else { 0.0 },
is_mul,
})
}
Some(py_dac)
}
Some(dac) => Some(pydac_from_dac(dac)),
}
}

#[pyfunction]
#[pyo3(name = "dac_from_file")]
fn compiler_from_file(file: String) -> PyDac {
pydac_from_dac(Dac::from_file(&PathBuf::from(file)))
}


/// Base module for pyschlandals
#[pymodule]
Expand Down
4 changes: 4 additions & 0 deletions src/compiler/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,10 @@ impl Dac {
pub fn is_circuit_node_mul(&self, node: CircuitNodeIndex) -> bool {
self.nodes[node.0].is_mul
}

pub fn get_circuit_node_probability(&self, node: CircuitNodeIndex) -> &Float {
&self.nodes[node.0].value
}

}

Expand Down

0 comments on commit c0ed7cf

Please sign in to comment.